Skip to content

Commit

Permalink
Implement forward methods + temp workarounds to inherit from retina.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Mar 12, 2021
1 parent 34237e4 commit 22d1dcd
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 26 deletions.
17 changes: 10 additions & 7 deletions torchvision/models/detection/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,15 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes):

return detections

def _anchors_per_level(self, features, HWA):
# recover level sizes
num_anchors_per_level = [x.size(2) * x.size(3) for x in features]
HW = 0
for v in num_anchors_per_level:
HW += v
A = HWA // HW
return [hw * A for hw in num_anchors_per_level]

def forward(self, images, targets=None):
# type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
"""
Expand Down Expand Up @@ -531,13 +540,7 @@ def forward(self, images, targets=None):
losses = self.compute_loss(targets, head_outputs, anchors)
else:
# recover level sizes
num_anchors_per_level = [x.size(2) * x.size(3) for x in features]
HW = 0
for v in num_anchors_per_level:
HW += v
HWA = head_outputs['cls_logits'].size(1)
A = HWA // HW
num_anchors_per_level = [hw * A for hw in num_anchors_per_level]
num_anchors_per_level = self._anchors_per_level(features, head_outputs['cls_logits'].size(1))

# split outputs per level
split_head_outputs: Dict[str, List[Tensor]] = {}
Expand Down
78 changes: 61 additions & 17 deletions torchvision/models/detection/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,35 +24,69 @@ def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes:
self.regression_head = SSDRegressionHead(in_channels, num_anchors)


class SSDClassificationHead(nn.Module):
def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int):
class SSDScoringHead(nn.Module):
def __init__(self, module_list: nn.ModuleList, num_columns: int):
super().__init__()
self.cls_logits = nn.ModuleList()
self.module_list = module_list
self.num_columns = num_columns

def get_result_from_module_list(self, x: Tensor, idx: int) -> Tensor:
    """
    This is equivalent to self.module_list[idx](x),
    but torchscript doesn't support this yet

    Args:
        x (Tensor): input passed to the selected module.
        idx (int): position of the module to apply; negative values count
            from the end of ``self.module_list``.

    Returns:
        Tensor: the output of the module at position ``idx``.

    Raises:
        IndexError: if ``idx`` does not address a module in ``self.module_list``.
    """
    num_blocks = len(self.module_list)
    if idx < 0:
        idx += num_blocks
    # Without this check an out-of-range index would fall through the loop
    # below and silently hand back the unprocessed input.
    if idx < 0 or idx >= num_blocks:
        raise IndexError("module_list index out of range")
    # Walk the list and pick the matching position instead of indexing it.
    i = 0
    out = x
    for module in self.module_list:
        if i == idx:
            out = module(x)
        i += 1
    return out

def forward(self, x: List[Tensor]) -> Tensor:
    """
    Apply the i-th module of ``self.module_list`` to the i-th feature map and
    concatenate the reshaped outputs of all levels along dim 1.

    Args:
        x (List[Tensor]): one feature map per level.

    Returns:
        Tensor: the per-level outputs, each reshaped to
        ``(N, -1, self.num_columns)``, concatenated along dim 1.
    """
    per_level: List[Tensor] = []

    for level, feature_map in enumerate(x):
        out = self.get_result_from_module_list(feature_map, level)

        # Permute output from (N, A * K, H, W) to (N, HWA, K),
        # where K is self.num_columns.
        batch, _, height, width = out.shape
        out = out.view(batch, -1, self.num_columns, height, width).permute(0, 3, 4, 1, 2)
        per_level.append(out.reshape(batch, -1, self.num_columns))  # Size=(N, HWA, K)

    return torch.cat(per_level, dim=1)


class SSDClassificationHead(SSDScoringHead):
def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int):
cls_logits = nn.ModuleList()
for channels, anchors in zip(in_channels, num_anchors):
self.cls_logits.append(nn.Conv2d(channels, num_classes * anchors, kernel_size=3, padding=1))
cls_logits.append(nn.Conv2d(channels, num_classes * anchors, kernel_size=3, padding=1))
super().__init__(cls_logits, num_classes)

def compute_loss(self, targets: List[Dict[str, Tensor]], head_outputs: Dict[str, Tensor],
                 matched_idxs: List[Tensor]) -> Tensor:
    # TODO: not implemented yet. The body is a bare `pass`, so calling this
    # currently returns None even though the signature declares a Tensor.
    pass

def forward(self, x: List[Tensor]) -> Tensor:
pass


class SSDRegressionHead(nn.Module):
class SSDRegressionHead(SSDScoringHead):
def __init__(self, in_channels: List[int], num_anchors: List[int]):
super().__init__()
self.bbox_reg = nn.ModuleList()
bbox_reg = nn.ModuleList()
for channels, anchors in zip(in_channels, num_anchors):
self.bbox_reg.append(nn.Conv2d(channels, 4 * anchors, kernel_size=3, padding=1))
bbox_reg.append(nn.Conv2d(channels, 4 * anchors, kernel_size=3, padding=1))
super().__init__(bbox_reg, 4)

def compute_loss(self, targets: List[Dict[str, Tensor]], head_outputs: Dict[str, Tensor], anchors: List[Tensor],
                 matched_idxs: List[Tensor]) -> Tensor:
    # TODO: not implemented yet. The body is a bare `pass`, so calling this
    # currently returns None even though the signature declares a Tensor.
    pass

def forward(self, x: List[Tensor]) -> Tensor:
pass


class SSD(RetinaNet):
def __init__(self, backbone: nn.Module, num_classes: int,
Expand Down Expand Up @@ -80,8 +114,8 @@ def __init__(self, backbone: nn.Module, num_classes: int,
self.backbone = backbone

# Estimate num of anchors based on aspect ratios: 2 default boxes + 2 * ratios of feature map.
num_anchors = [2 + 2 * len(r) for r in aspect_ratios]
self.head = SSDHead(out_channels, num_anchors, num_classes)
self.num_anchors = [2 + 2 * len(r) for r in aspect_ratios]
self.head = SSDHead(out_channels, self.num_anchors, num_classes)

self.anchor_generator = DBoxGenerator(size, feature_map_sizes, aspect_ratios)

Expand All @@ -97,7 +131,8 @@ def __init__(self, backbone: nn.Module, num_classes: int,
image_mean = [0.485, 0.456, 0.406]
if image_std is None:
image_std = [0.229, 0.224, 0.225]
self.transform = GeneralizedRCNNTransform(size, size, image_mean, image_std)
self.transform = GeneralizedRCNNTransform(size, size, image_mean, image_std,
size_divisible=1) # TODO: Discuss/refactor this workaround

self.score_thresh = score_thresh
self.nms_thresh = nms_thresh
Expand All @@ -107,6 +142,15 @@ def __init__(self, backbone: nn.Module, num_classes: int,
# used only on torchscript mode
self._has_warned = False

def _anchors_per_level(self, features, HWA):
# TODO: Discuss/refactor this workaround
num_anchors_per_level = [x.size(2) * x.size(3) * anchors for x, anchors in zip(features, self.num_anchors)]
HW = 0
for v in num_anchors_per_level:
HW += v
A = HWA // HW
return [hw * A for hw in num_anchors_per_level]

def compute_loss(self, targets: List[Dict[str, Tensor]], head_outputs: Dict[str, Tensor],
                 anchors: List[Tensor]) -> Dict[str, Tensor]:
    # TODO: not implemented yet. The body is a bare `pass`, so calling this
    # currently returns None even though the signature declares a dict of Tensors.
    pass
Expand Down Expand Up @@ -203,7 +247,7 @@ def ssd_vgg16(pretrained: bool = False, progress: bool = True, num_classes: int
pretrained_backbone = False

backbone = _vgg_backbone("vgg16", pretrained_backbone, trainable_layers=trainable_backbone_layers)
model = SSD(backbone, num_classes, **kwargs)
model = SSD(backbone, num_classes, **kwargs) # TODO: fix initializations in all new layers
if pretrained:
pass # TODO: load pre-trained COCO weights
return model
5 changes: 3 additions & 2 deletions torchvision/models/detection/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,15 @@ class GeneralizedRCNNTransform(nn.Module):
It returns a ImageList for the inputs, and a List[Dict[Tensor]] for the targets
"""

def __init__(self, min_size, max_size, image_mean, image_std):
def __init__(self, min_size, max_size, image_mean, image_std, size_divisible=32):
    """
    Store the configuration of the transform.

    Args:
        min_size: a single value or a list/tuple of values; a single value is
            wrapped into a 1-element tuple.
        max_size: stored as given.
        image_mean: stored as given.
        image_std: stored as given.
        size_divisible: stored as given and forwarded to ``batch_images``
            by ``forward``. Defaults to 32.
    """
    super(GeneralizedRCNNTransform, self).__init__()
    # Normalise min_size so it is always a list or tuple.
    self.min_size = min_size if isinstance(min_size, (list, tuple)) else (min_size,)
    self.max_size = max_size
    self.image_mean = image_mean
    self.image_std = image_std
    self.size_divisible = size_divisible

def forward(self,
images, # type: List[Tensor]
Expand Down Expand Up @@ -107,7 +108,7 @@ def forward(self,
targets[i] = target_index

image_sizes = [img.shape[-2:] for img in images]
images = self.batch_images(images)
images = self.batch_images(images, size_divisible=self.size_divisible)
image_sizes_list: List[Tuple[int, int]] = []
for image_size in image_sizes:
assert len(image_size) == 2
Expand Down

0 comments on commit 22d1dcd

Please sign in to comment.