-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Multi-pretrained weight support - FasterRCNN ResNet50 #4613
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some clarification comments below:
returned_layers=None, | ||
extra_blocks=None, | ||
): | ||
backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately I'm forced to copy the whole function just to change the pretrained
to weights
param. I refactored to minimize copy-pasted code.
|
||
|
||
# Allows handling of both PIL and Tensor images | ||
class ConvertImageDtype(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed the standalone transform to avoid introducing a new class here.
import warnings | ||
from typing import Any, Optional | ||
|
||
from ....models.detection.faster_rcnn import FasterRCNN, overwrite_eps, _validate_trainable_layers |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Inherit as much as possible. The changes below will be moved on the existing files once we move to torchvision.
|
||
def fasterrcnn_resnet50_fpn( | ||
weights: Optional[FasterRCNNResNet50FPNWeights] = None, | ||
weights_backbone: Optional[ResNet50Weights] = ResNet50Weights.ImageNet1K_RefV1, # TODO: Should we default to None? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This default value is to align with the old logic, NEVERTHELESS it's not necessary for BC (due to the way I handled the pretrained
param below).
Personally I don't like that we have hardcoded an OLD set of weights here. If we set the default value to None and force the user to choose, we will eliminate future BC considerations if these weights become too old and not-optimal.
This can be addressed on a follow up PR.
cc @fmassa
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm going to set this to None
, happy to review if someone has a strong opinion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @datumbox! Few minor comments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
|
||
if weights is not None: | ||
weights_backbone = None | ||
num_classes = len(weights.meta["categories"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should probably raise an error / warning if the user modifies the num_classes
and passes a weights
argument. Otherwise they might silently think that we are doing magic inside
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good. I'll add this check to resnet as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought about this and it's a bit problematic. The num_classes
parameter has a default value in all of our model builders. So to see i it was modified, we need to see if the default value was changed which can lead to messy code. An alternative approach could be to throw a warning if the num_classes
!= len(weights.meta["categories"])
but still overwrite it to make the life of users easier.
Because it's not clear how this should be handled, I'm going to merge the PR to unblock the work but I'm happy to discuss the policy here and update everywhere in a follow up PR.
Hey @datumbox! You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py |
* Adding FasterRCNN ResNet50. * Refactoring to remove duplicate code. * Adding typing info. * Setting weights_backbone=None as default value. * Overwrite eps only for specific weights.
Summary: * Adding FasterRCNN ResNet50. * Refactoring to remove duplicate code. * Adding typing info. * Setting weights_backbone=None as default value. * Overwrite eps only for specific weights. Reviewed By: NicolasHug Differential Revision: D31758312 fbshipit-source-id: 714a714d897bb4b4d9da1298ad5e2606998898b9
* Adding FasterRCNN ResNet50. * Refactoring to remove duplicate code. * Adding typing info. * Setting weights_backbone=None as default value. * Overwrite eps only for specific weights.
Resolves #4671
Example usage:
cc @datumbox @pmeier @bjuncek