Skip to content

Commit

Permalink
Moving vgg's extra layers a separate class + L2 scaling.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Mar 8, 2021
1 parent bffe4bc commit 03bc52c
Showing 1 changed file with 65 additions and 73 deletions.
138 changes: 65 additions & 73 deletions torchvision/models/detection/ssd.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import torch
import torch.nn.functional as F

from torch import nn, Tensor
from typing import Dict, List, Optional, Tuple

Expand Down Expand Up @@ -73,17 +76,73 @@ def forward(self, images: List[Tensor],
pass


class MultiFeatureMap(nn.Module):
class SSDFeatureExtractorVGG(nn.Module):

OUT_CHANNELS = (512, 1024, 512, 256, 256, 256)

def __init__(self, feature_maps: nn.ModuleList):
def __init__(self, backbone: nn.Module):
super().__init__()
self.feature_maps = feature_maps

# Patch ceil_mode for all maxpool layers of backbone to get the same WxH output sizes as the paper
penultimate_block_pos = ultimate_block_pos = None
for i, layer in enumerate(backbone):
if isinstance(layer, nn.MaxPool2d):
layer.ceil_mode = True
penultimate_block_pos = ultimate_block_pos
ultimate_block_pos = i

# parameters used for L2 regularization + rescaling
self.scale_weight = nn.Parameter(torch.ones(self.OUT_CHANNELS[0]) * 20)

# Multiple Feature maps - page 4, Fig 2 of SSD paper
self.block1 = nn.Sequential(
*backbone[:penultimate_block_pos] # until conv4_3
)
self.block2 = nn.Sequential(
*backbone[penultimate_block_pos:-1], # until conv5_3, skip maxpool5
nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), # add modified maxpool5
nn.Conv2d(in_channels=self.OUT_CHANNELS[0],
out_channels=1024, kernel_size=3, padding=6, dilation=6), # FC6 with atrous
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=1024, out_channels=self.OUT_CHANNELS[1], kernel_size=1), # FC7
nn.ReLU(inplace=True)
)
self.block3 = nn.Sequential(
nn.Conv2d(self.OUT_CHANNELS[1], 256, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, self.OUT_CHANNELS[2], kernel_size=3, padding=1, stride=2), # conv8_2
nn.ReLU(inplace=True),
)
self.block4 = nn.Sequential(
nn.Conv2d(self.OUT_CHANNELS[2], 128, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, self.OUT_CHANNELS[3], kernel_size=3, padding=1, stride=2), # conv9_2
nn.ReLU(inplace=True),
)
self.block5 = nn.Sequential(
nn.Conv2d(self.OUT_CHANNELS[3], 128, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, self.OUT_CHANNELS[4], kernel_size=3), # conv10_2
nn.ReLU(inplace=True),
)
self.block6 = nn.Sequential(
nn.Conv2d(self.OUT_CHANNELS[4], 128, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, self.OUT_CHANNELS[5], kernel_size=3), # conv11_2
nn.ReLU(inplace=True),
)

def forward(self, x):
output = []
for block in self.feature_maps:
# L2 regularization + Rescaling of 1st block's feature map
x = self.block1(x)
rescaled = self.scale_weight.view(1, -1, 1, 1) * F.normalize(x)
output = [rescaled]

# Calculating Feature maps for the rest blocks
for block in (self.block2, self.block3, self.block4, self.block5, self.block6):
x = block(x)
output.append(x)

return output


Expand All @@ -102,74 +161,7 @@ def _vgg_mfm_backbone(backbone_name, pretrained, trainable_layers=3):
for parameter in b.parameters():
parameter.requires_grad_(False)

# Patch ceil_mode for all maxpool layers of backbone to get the same outputs as Fig2 of SSD papers
for layer in backbone:
if isinstance(layer, nn.MaxPool2d):
layer.ceil_mode = True

# Multiple Feature map definition - page 4, Fig 2 of SSD paper
def build_feature_map_block(layers, out_channels):
block = nn.Sequential(*layers)
block.out_channels = out_channels
return block

penultimate_block_index = stage_indices[-2]
feature_maps = nn.ModuleList([
build_feature_map_block(
backbone[:penultimate_block_index], # until conv4_3
# TODO: add L2 nomarlization + scaling?
512
),
build_feature_map_block(
(
*backbone[penultimate_block_index:-1], # until conv5_3, skip last maxpool
nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), # add modified maxpool5
nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=6, dilation=6), # FC6 with atrous
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1), # FC7
nn.ReLU(inplace=True)
),
1024
),
build_feature_map_block(
(
nn.Conv2d(1024, 256, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2), # conv8_2
nn.ReLU(inplace=True),
),
512,
),
build_feature_map_block(
(
nn.Conv2d(512, 128, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2), # conv9_2
nn.ReLU(inplace=True),
),
256,
),
build_feature_map_block(
(
nn.Conv2d(256, 128, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, kernel_size=3), # conv10_2
nn.ReLU(inplace=True),
),
256,
),
build_feature_map_block(
(
nn.Conv2d(256, 128, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, kernel_size=3), # conv11_2
nn.ReLU(inplace=True),
),
256,
),
])

return MultiFeatureMap(feature_maps)
return SSDVGGFeatureExtractor(backbone)


def ssd_vgg16(pretrained=False, progress=True,
Expand Down

0 comments on commit 03bc52c

Please sign in to comment.