-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor Segmentation models #4646
Refactor Segmentation models #4646
Conversation
679b871
to
c4c5ac4
Compare
150ab32
to
de1d2ad
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Highlighting few important implementation details:
|
||
|
||
def _deeplabv3_resnet( | ||
backbone: resnet.ResNet, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We make all private model builders accept the pre-initialized backbones instead of passing the backbone_name
. This allows us to reuse the methods on the multi-pretrained weights project.
|
||
aux_classifier = FCNHead(1024, num_classes) if aux else None | ||
classifier = DeepLabHead(2048, num_classes) | ||
return DeepLabV3(backbone, classifier, aux_classifier) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall we simplify the code by splitting the previous massive _segm_model()
method.
progress: bool = True, | ||
num_classes: int = 21, | ||
aux_loss: bool = False, | ||
pretrained_backbone: bool = True, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of **kwargs
we directly expose the pretrained_backbone
parameter publicly.
|
||
from torch import nn | ||
from . import * # noqa: F401, F403 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Import all methods, classes, etc just in-case someone was using the package:
torchvision.models.segmentation.segmentation
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you mind just adding this as a comment so we don't "clean" it later?
_load_weights(model, "lraspp", backbone_name, progress) | ||
|
||
return model | ||
warnings.warn( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Throw a deprecation warning to use the parent package. Aka torchvision.models.segmentation
.
4e19c1d
to
ede7980
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR and for writing comments about the important parts.
This looks great, I just have one concern regarding the change of default for aux_loss
, LMK what you think
|
||
from torch import nn | ||
from . import * # noqa: F401, F403 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you mind just adding this as a comment so we don't "clean" it later?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @datumbox ! LGTM modulo some quick verification of the training perf as you mentioned already
(and the type checks apparently)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, thanks for cleaning this up!
if pretrained: | ||
arch = "deeplabv3_resnet50_coco" | ||
_load_weights(arch, model, model_urls.get(arch, None), progress) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I suppose you've left this here (instead of putting it in _deeplabv3_resnet
because it will be more aligned with your changes to the new weights?
Same for the backbone retrieval code
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes exactly. I was going back and forth about this. If I were to put it in the builder methods, I would have to copy-paste the whole thing...
I think an additional final clean up wold be necessary prior moving the prototype work to main and there we would be able to move things around. This is a great candidate for such clean up.
Thanks for the reviews. I confirmed that all models maintain the right model validation statistics when compared to the main branch:
I'll merge once the CI passes. |
Summary: * Move FCN methods to itsown package. * Fix lint. * Move LRASPP methods to their own package. * Move DeepLabV3 methods to their own package. * Adding deprecation warning for torchvision.models.segmentation.segmentation. * Refactoring deeplab. * Setting aux default to false. * Fixing imports. * Passing backbones instead of backbone names to builders. * Fixing mypy * Addressing review comments. * Correcting typing. * Restoring special handling for references. Reviewed By: datumbox Differential Revision: D31898217 fbshipit-source-id: 187bb6cc50ea14f121a907aaac4770f17696d6af
* Move FCN methods to itsown package. * Fix lint. * Move LRASPP methods to their own package. * Move DeepLabV3 methods to their own package. * Adding deprecation warning for torchvision.models.segmentation.segmentation. * Refactoring deeplab. * Setting aux default to false. * Fixing imports. * Passing backbones instead of backbone names to builders. * Fixing mypy * Addressing review comments. * Correcting typing. * Restoring special handling for references.
Fixes #4676
This PR refactors the entire
segmentation
package to align with the conventions followed byclassification
,detection
andquantization
.Here are the details:
models.segmentation.segmentation
now move to their respective model packages. For example,fcn_resnet50()
is now defined infcn.py
.prototype
with minimal copy-pasting.pretrained_backbone
instead viakwargs
.Accuracy remains the same for all pre-trained models. No modifications on tests required because the changes are fully-BC. This PR enables the work on #4611.