Skip to content

Commit

Permalink
Overwriting FrozenBN eps=0.0 if pretrained=True for detection models. (
Browse files Browse the repository at this point in the history
…pytorch#2940)

* Overwriting FrozenBN eps=0.0 if pretrained=True for detection models.

* Moving the method to detection utils and adding comments.
  • Loading branch information
datumbox authored and vfdev-5 committed Dec 4, 2020
1 parent 25689e8 commit c5b6736
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 5 deletions.
6 changes: 2 additions & 4 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import unittest
import random

from torchvision.ops.misc import FrozenBatchNorm2d
from torchvision.models.detection._utils import overwrite_eps


def set_rng_seed(seed):
Expand Down Expand Up @@ -151,9 +151,7 @@ def _test_detection_model(self, name, dev):
kwargs["score_thresh"] = 0.013
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs)
if "keypointrcnn" in name or "retinanet" in name:
for module in model.modules():
if isinstance(module, FrozenBatchNorm2d):
module.eps = 0
overwrite_eps(model, 0.0)
model.eval().to(device=dev)
input_shape = (3, 300, 300)
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
Expand Down
21 changes: 20 additions & 1 deletion torchvision/models/detection/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import torch
from torch.jit.annotations import List, Tuple
from torch import Tensor
import torchvision

from torchvision.ops.misc import FrozenBatchNorm2d


class BalancedPositiveNegativeSampler(object):
Expand Down Expand Up @@ -349,3 +350,21 @@ def smooth_l1_loss(input, target, beta: float = 1. / 9, size_average: bool = Tru
if size_average:
return loss.mean()
return loss.sum()


def overwrite_eps(model, eps):
"""
This method overwrites the default eps values of all the
FrozenBatchNorm2d layers of the model with the provided value.
This is necessary to address the BC-breaking change introduced
by the bug-fix at pytorch/vision#2933. The overwrite is applied
only when the pretrained weights are loaded to maintain compatibility
with previous versions.
Arguments:
model (nn.Module): The model on which we perform the overwrite.
eps (float): The new value of eps.
"""
for module in model.modules():
if isinstance(module, FrozenBatchNorm2d):
module.eps = eps
2 changes: 2 additions & 0 deletions torchvision/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops import MultiScaleRoIAlign

from ._utils import overwrite_eps
from ..utils import load_state_dict_from_url

from .anchor_utils import AnchorGenerator
Expand Down Expand Up @@ -361,4 +362,5 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
state_dict = load_state_dict_from_url(model_urls['fasterrcnn_resnet50_fpn_coco'],
progress=progress)
model.load_state_dict(state_dict)
overwrite_eps(model, 0.0)
return model
2 changes: 2 additions & 0 deletions torchvision/models/detection/keypoint_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from torchvision.ops import MultiScaleRoIAlign

from ._utils import overwrite_eps
from ..utils import load_state_dict_from_url

from .faster_rcnn import FasterRCNN
Expand Down Expand Up @@ -332,4 +333,5 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
state_dict = load_state_dict_from_url(model_urls[key],
progress=progress)
model.load_state_dict(state_dict)
overwrite_eps(model, 0.0)
return model
2 changes: 2 additions & 0 deletions torchvision/models/detection/mask_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops import MultiScaleRoIAlign

from ._utils import overwrite_eps
from ..utils import load_state_dict_from_url

from .faster_rcnn import FasterRCNN
Expand Down Expand Up @@ -328,4 +329,5 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
state_dict = load_state_dict_from_url(model_urls['maskrcnn_resnet50_fpn_coco'],
progress=progress)
model.load_state_dict(state_dict)
overwrite_eps(model, 0.0)
return model
2 changes: 2 additions & 0 deletions torchvision/models/detection/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch import Tensor
from torch.jit.annotations import Dict, List, Tuple

from ._utils import overwrite_eps
from ..utils import load_state_dict_from_url

from . import _utils as det_utils
Expand Down Expand Up @@ -628,4 +629,5 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True,
state_dict = load_state_dict_from_url(model_urls['retinanet_resnet50_fpn_coco'],
progress=progress)
model.load_state_dict(state_dict)
overwrite_eps(model, 0.0)
return model

0 comments on commit c5b6736

Please sign in to comment.