
RetinaNet object detection (take 2) #2784

Merged
merged 41 commits into pytorch:master on Oct 13, 2020

Conversation

@fmassa (Member) commented Oct 10, 2020

This is entirely based on top of the great work from @hgaiser in #1697

I'm creating a new PR because there are some minor things that could be fixed in that PR, but I don't have rights to push to it. To move faster I'm opening a new PR; all of the history is kept here.

Here is the mAP for the uploaded model:

IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.364
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.558
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.383
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.193
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.400
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.490
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.315
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.506
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.558
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.386
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.595
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.699

Hans Gaiser and others added 4 commits October 9, 2020 14:09
@@ -205,7 +205,7 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
             target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)

             # compute the loss
-            losses.append(det_utils.smooth_l1_loss(
+            losses.append(torch.nn.functional.l1_loss(
Contributor
Why switch it to regular l1 loss?

Member Author

This gives a 1 mAP improvement on the models, and we were still lagging a bit behind on the mAP compared to detectron2 (which has now adopted the L1 loss by default as well, see facebookresearch/detectron2@b0e2687)
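For context, a minimal sketch of the two regression losses being compared; the beta value below is illustrative, not necessarily the one used by the RetinaNet head (the `beta` kwarg on `F.smooth_l1_loss` also requires a recent PyTorch):

```python
import torch
import torch.nn.functional as F

pred = torch.randn(8, 4)    # predicted box regression deltas
target = torch.randn(8, 4)  # encoded ground-truth deltas

# Smooth L1 (Huber-like): quadratic for |error| < beta, linear beyond,
# so it down-weights small errors relative to plain L1.
loss_smooth = F.smooth_l1_loss(pred, target, beta=1.0 / 9, reduction="sum")

# Plain L1, as adopted in this PR (and by detectron2):
loss_l1 = F.l1_loss(pred, target, reduction="sum")
```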

Contributor

Ah nice, I wasn't expecting such a difference. What is the mAP that you are getting now? Also, wow, 37.4 mAP. That's impressive.

@fmassa (Member Author) commented Oct 13, 2020

Here are the mAP scores for the model I'll be uploading, with L1 loss:

IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.364
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.558
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.383
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.193
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.400
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.490
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.315
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.506
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.558
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.386
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.595
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.699

It is still lagging a bit behind D2, but it's closer. We might revisit the models in torchvision in the near future to improve mAP with the latest training tricks.

@@ -100,7 +100,8 @@ def compute_loss(self, targets, head_outputs, matched_idxs):
             foreground_idxs_per_image = matched_idxs_per_image >= 0
             num_foreground = foreground_idxs_per_image.sum()
             # no matched_idxs means there were no annotations in this image
-            if matched_idxs_per_image.numel() == 0:
+            if False:  # matched_idxs_per_image.numel() == 0:
+                # TODO: enable support for images without annotations that works on distributed
Contributor

What is the problem with images without annotations on distributed?

Member Author

There might be cases where on one GPU no images have annotations while on another they do. Part of the computation graph would then not be executed on one GPU, leading to synchronization issues (and even deadlocks).

So for now I'm disabling support for this to move forward, and we will enable it in a later PR (after the release).
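To illustrate the failure mode, a hypothetical sketch (the names `per_image_loss` and `regression_head` are stand-ins, not the real functions): DDP all-reduces gradients per parameter, so if a branch that touches some parameters runs on one rank but not another, the collectives no longer line up.

```python
import torch

def per_image_loss(features, matched_idxs, regression_head):
    # Hypothetical sketch: if this rank has no annotated images, the
    # regression head never runs here and its parameters get no grads...
    foreground = matched_idxs >= 0
    if foreground.sum() == 0:
        return features.sum() * 0
    # ...while ranks that do have annotations execute this branch, so
    # DDP's per-parameter gradient all-reduce fires on some ranks but
    # not others, and the job can hang or deadlock.
    return regression_head(features[foreground]).abs().mean()
```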

Member Author

Note that we are not passing find_unused_parameters to DDP, which you were probably using (because I had to change something else in resnet_fpn_backbone for it to work). With find_unused_parameters=True this might not be an issue, but I prefer to be on the safe side and keep the computation graph the same on every GPU.
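For reference, a minimal sketch of the DDP option being discussed; the model and rank here are illustrative placeholders:

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Placeholder model and rank; in a real script these come from your
# training setup after torch.distributed.init_process_group().
local_rank = 0
model = nn.Linear(4, 4).to(local_rank)

# find_unused_parameters=True makes DDP detect, each iteration, which
# parameters did not contribute to the loss, so the gradient all-reduce
# does not wait on them -- at some extra traversal cost per step.
model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)
```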

@fmassa fmassa merged commit 5bb81c8 into pytorch:master Oct 13, 2020
@fmassa fmassa deleted the retinanet branch October 13, 2020 11:02
@liminghu

Thanks a lot. Any tutorial on how to train the RetinaNet on COCO or other datasets?

@fmassa (Member Author) commented Nov 12, 2020

@liminghu We have training scripts for all detection models in https://github.com/pytorch/vision/tree/master/references/detection and a finetuning tutorial (for Mask R-CNN) in https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
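As a quick-start complement to those scripts, a minimal sketch of constructing and running the model added in this PR (the pretrained flag follows torchvision's detection API at the time of this PR):

```python
import torch
import torchvision

# RetinaNet added in this PR, with COCO-pretrained weights.
model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=True)
model.eval()

# Inference takes a list of CHW float tensors in [0, 1] and returns one
# dict per image with 'boxes', 'scores', and 'labels'.
images = [torch.rand(3, 480, 640)]
predictions = model(images)

# For training, also pass targets (boxes + labels); the model then
# returns a dict of losses ('classification', 'bbox_regression').
```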

bryant1410 pushed a commit to bryant1410/vision-1 that referenced this pull request Nov 22, 2020
* Add rough implementation of RetinaNet.

* Move AnchorGenerator to a separate file.

* Move box similarity to Matcher.

* Expose extra blocks in FPN.

* Expose retinanet in __init__.py.

* Use P6 and P7 in FPN for retinanet.

* Use parameters from retinanet for anchor generation.

* General fixes for retinanet model.

* Implement loss for retinanet heads.

* Output reshaped outputs from retinanet heads.

* Add postprocessing of detections.

* Small fixes.

* Remove unused argument.

* Remove python2 invocation of super.

* Add postprocessing for additional outputs.

* Add missing import of ImageList.

* Remove redundant import.

* Simplify class correction.

* Fix pylint warnings.

* Remove the label adjustment for background class.

* Set default score threshold to 0.05.

* Add weight initialization for regression layer.

* Allow training on images with no annotations.

* Use smooth_l1_loss with beta value.

* Add more typehints for TorchScript conversions.

* Fix linting issues.

* Fix type hints in postprocess_detections.

* Fix type annotations for TorchScript.

* Fix inconsistency with matched_idxs.

* Add retinanet model test.

* Add missing JIT annotations.

* Remove redundant model construction

Make tests pass

* Fix bugs during training on newer PyTorch and unused params in DDP

Needs cleanup and to add back support for images with no annotations

* Cleanup resnet_fpn_backbone

* Use L1 loss for regression

Gives 1mAP improvement over smooth l1

* Disable support for images with no annotations

Need to fix distributed first

* Fix retinanet tests

Need to deduplicate those box checks

* Fix Lint

* Add pretrained model

* Add training info for retinanet

Co-authored-by: Hans Gaiser <[email protected]>
vfdev-5 pushed a commit to Quansight/vision that referenced this pull request Dec 4, 2020