Refactor Segmentation models #4646

Merged
merged 16 commits into from
Oct 19, 2021

Conversation

datumbox

@datumbox datumbox commented Oct 18, 2021

Fixes #4676

This PR refactors the entire segmentation package to align with the conventions followed by classification, detection and quantization.

Here are the details:

  • The public model builders located at models.segmentation.segmentation now move to their respective model packages. For example, fcn_resnet50() is now defined in fcn.py.
  • The private builder methods of each Model are refactored, so that they can be reused on prototype with minimal copy-pasting.
  • We expose pretrained_backbone publicly as an explicit parameter instead of passing it via kwargs.

Accuracy remains the same for all pre-trained models. No test modifications were required because the changes are fully backwards-compatible (BC). This PR enables the work on #4611.

@datumbox datumbox force-pushed the models/segmentation_refactoring branch from 679b871 to c4c5ac4 Compare October 18, 2021 17:50
@datumbox datumbox changed the title [WIP] Refactor Segmentation models Refactor Segmentation models Oct 18, 2021
@datumbox datumbox marked this pull request as ready for review October 18, 2021 18:32
@datumbox datumbox force-pushed the models/segmentation_refactoring branch from 150ab32 to de1d2ad Compare October 18, 2021 19:02

@datumbox datumbox left a comment


Highlighting a few important implementation details:



def _deeplabv3_resnet(
backbone: resnet.ResNet,

@datumbox datumbox Oct 18, 2021


We make all private model builders accept the pre-initialized backbones instead of passing the backbone_name. This allows us to reuse the methods on the multi-pretrained weights project.
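
To make the idea concrete, here is a minimal sketch that runs without torch. `ToyBackbone`, `ToyHead`, `ToySegModel` and `_toy_deeplab_builder` are hypothetical stand-ins, not torchvision names; the channel counts are copied from the diff excerpt in this review. The only point is that the private builder receives an already constructed backbone, so the caller decides how it is created and which weights it carries.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyBackbone:
    # Hypothetical stand-in for a pre-initialized backbone such as resnet.ResNet.
    name: str
    weights: Optional[str] = None


@dataclass
class ToyHead:
    # Hypothetical stand-in for a classifier head.
    in_channels: int
    num_classes: int


@dataclass
class ToySegModel:
    backbone: ToyBackbone
    classifier: ToyHead
    aux_classifier: Optional[ToyHead]


def _toy_deeplab_builder(backbone: ToyBackbone, num_classes: int, aux: bool) -> ToySegModel:
    # The builder never looks a backbone up by name; it only wires the parts together.
    aux_classifier = ToyHead(1024, num_classes) if aux else None
    classifier = ToyHead(2048, num_classes)
    return ToySegModel(backbone, classifier, aux_classifier)


# Two callers reuse the same builder with differently initialized backbones.
legacy = _toy_deeplab_builder(ToyBackbone("resnet50", weights="imagenet"), 21, aux=False)
prototype = _toy_deeplab_builder(ToyBackbone("resnet50", weights=None), 21, aux=True)
```

Because the builder has no opinion about where the backbone comes from, a second code path (for example one that selects among several sets of weights) can call it unchanged.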


aux_classifier = FCNHead(1024, num_classes) if aux else None
classifier = DeepLabHead(2048, num_classes)
return DeepLabV3(backbone, classifier, aux_classifier)
Contributor Author

Overall we simplify the code by splitting the previous massive _segm_model() method.

torchvision/models/segmentation/deeplabv3.py (outdated, resolved)
progress: bool = True,
num_classes: int = 21,
aux_loss: bool = False,
pretrained_backbone: bool = True,
Contributor Author

Instead of **kwargs we directly expose the pretrained_backbone parameter publicly.
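
A small sketch of why the explicit parameter helps, assuming simplified builder signatures. `old_style_builder` and `new_style_builder` are hypothetical and return plain dicts instead of models; the parameter list of the second one mirrors the diff excerpt shown in this thread.

```python
import inspect


def old_style_builder(pretrained=False, progress=True, num_classes=21, **kwargs):
    # Hypothetical pre-refactor shape: the backbone flag is hidden inside **kwargs.
    pretrained_backbone = kwargs.get("pretrained_backbone", True)
    return {"num_classes": num_classes, "pretrained_backbone": pretrained_backbone}


def new_style_builder(
    pretrained=False,
    progress=True,
    num_classes=21,
    aux_loss=False,
    pretrained_backbone=True,
):
    # Hypothetical post-refactor shape: the flag is part of the public signature.
    return {"num_classes": num_classes, "pretrained_backbone": pretrained_backbone}


# The explicit parameter is visible to signature inspection (docs, IDEs, type checkers).
has_flag_old = "pretrained_backbone" in inspect.signature(old_style_builder).parameters
has_flag_new = "pretrained_backbone" in inspect.signature(new_style_builder).parameters

# A misspelled option is rejected by the explicit signature.
try:
    new_style_builder(pretrained_backbon=False)  # deliberate typo
    typo_rejected = False
except TypeError:
    typo_rejected = True
```

Existing callers that already passed `pretrained_backbone=...` by keyword keep working with either shape, which is consistent with the change being backwards-compatible.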


from torch import nn
from . import * # noqa: F401, F403
Contributor Author

Import all methods, classes, etc. just in case someone was using the torchvision.models.segmentation.segmentation module directly.

Member

Do you mind just adding this as a comment so we don't "clean" it later?

_load_weights(model, "lraspp", backbone_name, progress)

return model
warnings.warn(
Contributor Author

Emit a deprecation warning that directs users to the parent package, i.e. torchvision.models.segmentation.
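
For readers unfamiliar with this kind of compatibility shim, the sketch below emulates it in a self-contained way. The module name `toy_segmentation`, the helper `load_legacy_shim`, the message wording and the `DeprecationWarning` category are all assumptions made for illustration; the real file relies on `from . import *` at import time rather than on a helper function.

```python
import types
import warnings

# Hypothetical parent package holding the public builders.
parent = types.ModuleType("toy_segmentation")
parent.fcn_resnet50 = lambda: "fcn_resnet50 model"
parent._private_helper = lambda: "not re-exported"


def load_legacy_shim(parent_module: types.ModuleType) -> types.ModuleType:
    """Emulate importing a legacy module that re-exports the parent's public names."""
    shim = types.ModuleType(parent_module.__name__ + ".segmentation")
    # Equivalent in spirit to `from . import *` without __all__:
    # copy every name that does not start with an underscore.
    for name, value in vars(parent_module).items():
        if not name.startswith("_"):
            setattr(shim, name, value)
    # Warn once, at "import" time, and point users to the parent package.
    warnings.warn(
        f"The '{shim.__name__}' module is deprecated; use '{parent_module.__name__}' instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return shim


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = load_legacy_shim(parent)
```

Old call sites such as `legacy.fcn_resnet50()` keep working, while the warning tells users where the builders now live.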

@datumbox datumbox force-pushed the models/segmentation_refactoring branch from 4e19c1d to ede7980 Compare October 18, 2021 19:38

@NicolasHug NicolasHug left a comment


Thanks for the PR and for writing comments about the important parts.
This looks great. I just have one concern regarding the change of default for aux_loss; let me know what you think.


torchvision/models/segmentation/deeplabv3.py (outdated, resolved)

@NicolasHug NicolasHug left a comment


Thanks @datumbox! LGTM modulo some quick verification of the training perf, as you mentioned already

(and the type checks apparently)


@fmassa fmassa left a comment


Looks great, thanks for cleaning this up!

Comment on lines +178 to +180
if pretrained:
arch = "deeplabv3_resnet50_coco"
_load_weights(arch, model, model_urls.get(arch, None), progress)
Member

nit: I suppose you've left this here (instead of putting it in _deeplabv3_resnet) because it will be more aligned with your changes to the new weights?

Same for the backbone retrieval code

Contributor Author

Yes exactly. I was going back and forth about this. If I were to put it in the builder methods, I would have to copy-paste the whole thing...

I think an additional final clean-up would be necessary prior to moving the prototype work to main, and there we would be able to move things around. This is a great candidate for such a clean-up.

@datumbox

Thanks for the reviews. I confirmed that all models maintain the right model validation statistics when compared to the main branch:

torchrun --nproc_per_node=2 train.py --test-only --pretrained --model fcn_resnet50
global correct: 91.4
mean IoU: 60.5

torchrun --nproc_per_node=2 train.py --test-only --pretrained --model fcn_resnet101
global correct: 91.9
mean IoU: 63.7

torchrun --nproc_per_node=2 train.py --test-only --pretrained --model deeplabv3_resnet50
global correct: 92.4
mean IoU: 66.4

torchrun --nproc_per_node=2 train.py --test-only --pretrained --model deeplabv3_resnet101
global correct: 92.4
mean IoU: 67.4

torchrun --nproc_per_node=2 train.py --test-only --pretrained --model deeplabv3_mobilenet_v3_large
global correct: 91.2
mean IoU: 60.3

torchrun --nproc_per_node=2 train.py --test-only --pretrained --model lraspp_mobilenet_v3_large
global correct: 91.2
mean IoU: 57.9

I'll merge once the CI passes.
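
For anyone repeating this check, a comparison like the one above could be automated roughly as below. This is only a sketch: `parse_metrics` and `metrics_match` are hypothetical helpers, the tolerance is an arbitrary choice, and the `main_output` value assumes main prints the same numbers as this branch, as reported in the comment above.

```python
import re
from typing import Dict


def parse_metrics(output: str) -> Dict[str, float]:
    """Extract 'global correct' and 'mean IoU' values from the evaluation output."""
    pattern = re.compile(r"^\s*(global correct|mean IoU):\s*([0-9.]+)\s*$", re.MULTILINE)
    return {name: float(value) for name, value in pattern.findall(output)}


def metrics_match(a: Dict[str, float], b: Dict[str, float], tol: float = 0.05) -> bool:
    """Return True when both runs report the same metrics within `tol`."""
    return a.keys() == b.keys() and all(abs(a[k] - b[k]) <= tol for k in a)


# Output reported above for fcn_resnet50 on this branch.
branch_output = """global correct: 91.4
mean IoU: 60.5
"""
# Assumption: the main branch prints the same numbers.
main_output = branch_output

branch = parse_metrics(branch_output)
same_as_main = metrics_match(branch, parse_metrics(main_output))
```

Any other lines printed by the evaluation script are ignored by the regular expression, so the helper only needs the two summary lines to be present.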

@datumbox datumbox merged commit a1d6b31 into pytorch:main Oct 19, 2021
@datumbox datumbox deleted the models/segmentation_refactoring branch October 19, 2021 12:40
facebook-github-bot pushed a commit that referenced this pull request Oct 25, 2021
Summary:
* Move FCN methods to their own package.

* Fix lint.

* Move LRASPP methods to their own package.

* Move DeepLabV3 methods to their own package.

* Adding deprecation warning for torchvision.models.segmentation.segmentation.

* Refactoring deeplab.

* Setting aux default to false.

* Fixing imports.

* Passing backbones instead of backbone names to builders.

* Fixing mypy

* Addressing review comments.

* Correcting typing.

* Restoring special handling for references.

Reviewed By: datumbox

Differential Revision: D31898217

fbshipit-source-id: 187bb6cc50ea14f121a907aaac4770f17696d6af
cyyever pushed a commit to cyyever/vision that referenced this pull request Nov 16, 2021
Successfully merging this pull request may close these issues.

Refactor Models builders to make them reusable on Prototype
4 participants