Multi-pretrained weight support - FasterRCNN ResNet50 #4613

Merged (6 commits) on Oct 15, 2021

datumbox
Contributor

@datumbox datumbox commented Oct 14, 2021

Resolves #4671

Example usage:

from torchvision.io.image import read_image
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image
from torchvision import prototype as P


img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Initialize model
weights = P.models.detection.FasterRCNNResNet50FPNWeights.Coco_RefV1
model = P.models.detection.fasterrcnn_resnet50_fpn(weights=weights)
model.eval()

# Initialize inference transforms
preprocess = weights.transforms()

# Apply inference preprocessing transforms
batch = [preprocess(img)[0]]
prediction = model(batch)[0]

# Use meta to get labels
labels = [weights.meta["categories"][i] for i in prediction["labels"]]
box = draw_bounding_boxes(img, boxes=prediction["boxes"],
                          labels=labels,
                          colors="red",
                          width=4, font_size=30)
im = to_pil_image(box.detach())
im.show()

cc @datumbox @pmeier @bjuncek

Contributor Author

@datumbox datumbox left a comment

Some clarification comments below:

    returned_layers=None,
    extra_blocks=None,
):
    backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
Contributor Author

@datumbox datumbox Oct 14, 2021

Unfortunately, I'm forced to copy the whole function just to replace the pretrained param with weights. I refactored to minimize the copy-pasted code.
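
A minimal sketch of that kind of refactor, assuming plain dictionaries as stand-ins for real ResNet modules. All names here (`make_backbone`, `_build_extractor`, `legacy_backbone`, `prototype_backbone`) are hypothetical and are not the functions changed in this PR. The idea is that the shared post-construction logic lives in one helper, so the two public builders differ only in how the backbone is created:

```python
from typing import Any, Dict, Optional


def make_backbone(name: str, source: Optional[str]) -> Dict[str, Any]:
    """Hypothetical stand-in for constructing a ResNet; returns a plain dict."""
    return {"name": name, "loaded_from": source}


def _build_extractor(backbone: Dict[str, Any], trainable_layers: int) -> Dict[str, Any]:
    """Shared logic that both builders delegate to (e.g. choosing trainable layers)."""
    if not 0 <= trainable_layers <= 5:
        raise ValueError(f"trainable_layers must be in [0, 5], got {trainable_layers}")
    all_layers = ["layer4", "layer3", "layer2", "layer1", "conv1"]
    return {"backbone": backbone, "trainable": all_layers[:trainable_layers]}


def legacy_backbone(name: str, pretrained: bool, trainable_layers: int = 3) -> Dict[str, Any]:
    """Old-style entry point keyed on a boolean flag."""
    backbone = make_backbone(name, "default-checkpoint" if pretrained else None)
    return _build_extractor(backbone, trainable_layers)


def prototype_backbone(name: str, weights: Optional[str], trainable_layers: int = 3) -> Dict[str, Any]:
    """New-style entry point keyed on an explicit weights identifier."""
    backbone = make_backbone(name, weights)
    return _build_extractor(backbone, trainable_layers)
```

With this split, only the one-line backbone construction is duplicated between the two entry points.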



# Allows handling of both PIL and Tensor images
class ConvertImageDtype(nn.Module):
Contributor Author

Removed the standalone transform to avoid introducing a new class here.

import warnings
from typing import Any, Optional

from ....models.detection.faster_rcnn import FasterRCNN, overwrite_eps, _validate_trainable_layers
Contributor Author

Inherit as much as possible. The changes below will be moved to the existing files once we move to torchvision.


def fasterrcnn_resnet50_fpn(
    weights: Optional[FasterRCNNResNet50FPNWeights] = None,
    weights_backbone: Optional[ResNet50Weights] = ResNet50Weights.ImageNet1K_RefV1,  # TODO: Should we default to None?
Contributor Author

@datumbox datumbox Oct 14, 2021

This default value aligns with the old logic. Nevertheless, it's not necessary for BC (due to the way I handled the pretrained param below).

Personally, I don't like that we have hardcoded an OLD set of weights here. If we set the default value to None and force the user to choose, we will eliminate future BC considerations if these weights become outdated and suboptimal.

This can be addressed in a follow-up PR.

cc @fmassa

Contributor Author

I'm going to set this to None. Happy to revisit if someone has a strong opinion.
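
For illustration, here is one way a builder could keep accepting the legacy boolean flags while defaulting `weights_backbone` to None. This is a sketch under assumptions, not the code in this PR: the thread does not show how the pretrained param is actually handled, the enum classes are hypothetical stand-ins for `FasterRCNNResNet50FPNWeights` and `ResNet50Weights`, and `resolve_weights` and the `pretrained_backbone` keyword handling are made up for the example.

```python
import warnings
from enum import Enum
from typing import Any, Optional, Tuple


class DetectionWeights(Enum):
    """Hypothetical stand-in for FasterRCNNResNet50FPNWeights."""
    Coco_RefV1 = "coco_ref_v1"


class BackboneWeights(Enum):
    """Hypothetical stand-in for ResNet50Weights."""
    ImageNet1K_RefV1 = "imagenet1k_ref_v1"


def resolve_weights(
    weights: Optional[DetectionWeights] = None,
    weights_backbone: Optional[BackboneWeights] = None,
    **kwargs: Any,
) -> Tuple[Optional[DetectionWeights], Optional[BackboneWeights]]:
    """Translate legacy boolean flags (assumed names) into enum-based weights."""
    if "pretrained" in kwargs:
        warnings.warn("'pretrained' is deprecated, use 'weights' instead.", DeprecationWarning)
        weights = DetectionWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
    if "pretrained_backbone" in kwargs:
        warnings.warn(
            "'pretrained_backbone' is deprecated, use 'weights_backbone' instead.",
            DeprecationWarning,
        )
        flag = kwargs.pop("pretrained_backbone")
        weights_backbone = BackboneWeights.ImageNet1K_RefV1 if flag else None
    # Full-model weights already include the backbone parameters,
    # so separately loading backbone weights would be redundant.
    if weights is not None:
        weights_backbone = None
    return weights, weights_backbone
```

Because the old defaults are only reproduced when a caller passes the legacy flags, the new-style signature can default to None without changing what existing call sites load.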

Contributor

@prabhat00155 prabhat00155 left a comment

Thanks @datumbox! A few minor comments.

torchvision/models/detection/backbone_utils.py (outdated, resolved)
torchvision/prototype/models/__init__.py (resolved)
torchvision/prototype/models/_meta.py (resolved)
Member

@fmassa fmassa left a comment

LGTM, thanks!

torchvision/prototype/models/detection/faster_rcnn.py (outdated, resolved)

if weights is not None:
    weights_backbone = None
    num_classes = len(weights.meta["categories"])
Member

We should probably raise an error or warning if the user modifies num_classes and also passes a weights argument. Otherwise they might silently assume that we are doing some magic inside.

Contributor Author

Sounds good. I'll add this check to resnet as well.

Contributor Author

I thought about this and it's a bit problematic. The num_classes parameter has a default value in all of our model builders. So to see if it was modified, we need to check whether the default value was changed, which can lead to messy code. An alternative approach could be to throw a warning if num_classes != len(weights.meta["categories"]) but still overwrite it to make life easier for users.

Because it's not clear how this should be handled, I'm going to merge the PR to unblock the work, but I'm happy to discuss the policy here and update everywhere in a follow-up PR.
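
To make the two options concrete, here is a rough sketch of both policies. It is not code from this PR: the function names are hypothetical, the default of 91 is assumed purely for illustration, and `categories` stands in for `weights.meta["categories"]` (None meaning no weights were passed). The first function warns and overwrites; the second uses None as a sentinel so that an explicitly passed value can be told apart from the default.

```python
import warnings
from typing import Optional, Sequence


def resolve_num_classes(num_classes: int, categories: Optional[Sequence[str]]) -> int:
    """Warn-and-overwrite policy: the weights always win, mismatches are reported."""
    if categories is None:
        return num_classes
    expected = len(categories)
    if num_classes != expected:
        warnings.warn(
            f"num_classes={num_classes} does not match the {expected} categories "
            f"of the selected weights; using {expected}."
        )
    return expected


def resolve_num_classes_strict(
    num_classes: Optional[int],
    categories: Optional[Sequence[str]],
    default: int = 91,  # assumed builder default, for illustration only
) -> int:
    """Sentinel policy: None means 'not set by the user'; explicit mismatches raise."""
    if categories is None:
        return default if num_classes is None else num_classes
    expected = len(categories)
    if num_classes is not None and num_classes != expected:
        raise ValueError(
            f"num_classes={num_classes} conflicts with the {expected} categories of the weights."
        )
    return expected
```

The trade-off: the warn-and-overwrite variant keeps the existing signature but also warns when an untouched default happens to differ from the weights, while the sentinel variant avoids that at the cost of changing the parameter's default to None in every builder.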

@datumbox datumbox merged commit ec2456a into pytorch:main Oct 15, 2021
@datumbox datumbox deleted the multiweight/fasterrcnn branch October 15, 2021 12:49
@github-actions

Hey @datumbox!

You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

mszhanyi pushed a commit to mszhanyi/vision that referenced this pull request Oct 19, 2021
* Adding FasterRCNN ResNet50.

* Refactoring to remove duplicate code.

* Adding typing info.

* Setting weights_backbone=None as default value.

* Overwrite eps only for specific weights.
facebook-github-bot pushed a commit that referenced this pull request Oct 19, 2021
Summary:
* Adding FasterRCNN ResNet50.

* Refactoring to remove duplicate code.

* Adding typing info.

* Setting weights_backbone=None as default value.

* Overwrite eps only for specific weights.

Reviewed By: NicolasHug

Differential Revision: D31758312

fbshipit-source-id: 714a714d897bb4b4d9da1298ad5e2606998898b9
cyyever pushed a commit to cyyever/vision that referenced this pull request Nov 16, 2021
Successfully merging this pull request may close these issues.

Multi-pretrained weights: Add initial API and basic implementation