Skip to content

Commit

Permalink
Refactor grid default boxes with torch meshgrid (#3799)
Browse files Browse the repository at this point in the history
* Refactor grid default boxes with torch.meshgrid

* Fix torch jit tracing

* Only doing the list multiplication once

Co-authored-by: Francisco Massa <[email protected]>

* Make grid_default_box private as suggested

Co-authored-by: Vasilis Vryniotis <[email protected]>

* Replace list multiplication with torch.repeat

* Move the clipping into _grid_default_boxes to accelerate

Co-authored-by: Francisco Massa <[email protected]>
Co-authored-by: Vasilis Vryniotis <[email protected]>
  • Loading branch information
3 people authored May 11, 2021
1 parent 5dd7dfe commit 48441cc
Showing 1 changed file with 40 additions and 25 deletions.
65 changes: 40 additions & 25 deletions torchvision/models/detection/anchor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,26 +170,59 @@ def __init__(self, aspect_ratios: List[List[int]], min_ratio: float = 0.15, max_
else:
self.scales = scales

self._wh_pairs = []
self._wh_pairs = self._generate_wh_pairs(num_outputs)

def _generate_wh_pairs(self, num_outputs: int, dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cpu")) -> List[Tensor]:
_wh_pairs: List[Tensor] = []
for k in range(num_outputs):
# Adding the 2 default width-height pairs for aspect ratio 1 and scale s'k
s_k = self.scales[k]
s_prime_k = math.sqrt(self.scales[k] * self.scales[k + 1])
wh_pairs = [(s_k, s_k), (s_prime_k, s_prime_k)]
wh_pairs = [[s_k, s_k], [s_prime_k, s_prime_k]]

# Adding 2 pairs for each aspect ratio of the feature map k
for ar in self.aspect_ratios[k]:
sq_ar = math.sqrt(ar)
w = self.scales[k] * sq_ar
h = self.scales[k] / sq_ar
wh_pairs.extend([(w, h), (h, w)])
wh_pairs.extend([[w, h], [h, w]])

self._wh_pairs.append(wh_pairs)
_wh_pairs.append(torch.as_tensor(wh_pairs, dtype=dtype, device=device))
return _wh_pairs

def num_anchors_per_location(self):
# Estimate num of anchors based on aspect ratios: 2 default boxes + 2 * ratios of feaure map.
return [2 + 2 * len(r) for r in self.aspect_ratios]

# Default Boxes calculation based on page 6 of SSD paper
def _grid_default_boxes(self, grid_sizes: List[List[int]], image_size: List[int],
dtype: torch.dtype = torch.float32) -> Tensor:
default_boxes = []
for k, f_k in enumerate(grid_sizes):
# Now add the default boxes for each width-height pair
if self.steps is not None:
x_f_k, y_f_k = [img_shape / self.steps[k] for img_shape in image_size]
else:
y_f_k, x_f_k = f_k

shifts_x = (torch.arange(0, f_k[1], dtype=dtype) + 0.5) / x_f_k
shifts_y = (torch.arange(0, f_k[0], dtype=dtype) + 0.5) / y_f_k
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)

shifts = torch.stack((shift_x, shift_y) * len(self._wh_pairs[k]), dim=-1).reshape(-1, 2)
# Clipping the default boxes while the boxes are encoded in format (cx, cy, w, h)
_wh_pair = self._wh_pairs[k].clamp(min=0, max=1) if self.clip else self._wh_pairs[k]
wh_pairs = _wh_pair.repeat((f_k[0] * f_k[1]), 1)

default_box = torch.cat((shifts, wh_pairs), dim=1)

default_boxes.append(default_box)

return torch.cat(default_boxes, dim=0)

def __repr__(self) -> str:
s = self.__class__.__name__ + '('
s += 'aspect_ratios={aspect_ratios}'
Expand All @@ -203,30 +236,12 @@ def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Ten
grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
image_size = image_list.tensors.shape[-2:]
dtype, device = feature_maps[0].dtype, feature_maps[0].device

# Default Boxes calculation based on page 6 of SSD paper
default_boxes: List[List[float]] = []
for k, f_k in enumerate(grid_sizes):
# Now add the default boxes for each width-height pair
for j in range(f_k[0]):
if self.steps is not None:
y_f_k = image_size[1] / self.steps[k]
else:
y_f_k = float(f_k[0])
cy = (j + 0.5) / y_f_k
for i in range(f_k[1]):
if self.steps is not None:
x_f_k = image_size[0] / self.steps[k]
else:
x_f_k = float(f_k[1])
cx = (i + 0.5) / x_f_k
default_boxes.extend([[cx, cy, w, h] for w, h in self._wh_pairs[k]])
default_boxes = self._grid_default_boxes(grid_sizes, image_size, dtype=dtype)
default_boxes = default_boxes.to(device)

dboxes = []
for _ in image_list.image_sizes:
dboxes_in_image = torch.tensor(default_boxes, dtype=dtype, device=device)
if self.clip:
dboxes_in_image.clamp_(min=0, max=1)
dboxes_in_image = default_boxes
dboxes_in_image = torch.cat([dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:],
dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:]], -1)
dboxes_in_image[:, 0::2] *= image_size[1]
Expand Down

0 comments on commit 48441cc

Please sign in to comment.