
Refactor grid default boxes with torch meshgrid #3799

Merged: 8 commits, May 11, 2021
65 changes: 40 additions & 25 deletions torchvision/models/detection/anchor_utils.py
@@ -170,26 +170,59 @@ def __init__(self, aspect_ratios: List[List[int]], min_ratio: float = 0.15, max_
else:
self.scales = scales

self._wh_pairs = []
self._wh_pairs = self._generate_wh_pairs(num_outputs)

def _generate_wh_pairs(self, num_outputs: int, dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cpu")) -> List[Tensor]:
_wh_pairs: List[Tensor] = []
for k in range(num_outputs):
# Adding the 2 default width-height pairs for aspect ratio 1 and scale s'k
s_k = self.scales[k]
s_prime_k = math.sqrt(self.scales[k] * self.scales[k + 1])
wh_pairs = [(s_k, s_k), (s_prime_k, s_prime_k)]
wh_pairs = [[s_k, s_k], [s_prime_k, s_prime_k]]

# Adding 2 pairs for each aspect ratio of the feature map k
for ar in self.aspect_ratios[k]:
sq_ar = math.sqrt(ar)
w = self.scales[k] * sq_ar
h = self.scales[k] / sq_ar
wh_pairs.extend([(w, h), (h, w)])
wh_pairs.extend([[w, h], [h, w]])

self._wh_pairs.append(wh_pairs)
_wh_pairs.append(torch.as_tensor(wh_pairs, dtype=dtype, device=device))
return _wh_pairs
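
For context, a minimal standalone sketch of the same width-height pair construction (page 6 of the SSD paper), using a hypothetical pair of scales and a single extra aspect ratio; the numbers below are illustrative, not torchvision's defaults:

import math
import torch

scales = [0.2, 0.34]   # hypothetical s_k and s_(k+1)
aspect_ratios = [2]    # hypothetical extra aspect ratio for this feature map

s_k = scales[0]
s_prime_k = math.sqrt(scales[0] * scales[1])     # geometric-mean scale for the extra 1:1 box
wh_pairs = [[s_k, s_k], [s_prime_k, s_prime_k]]  # the 2 square default boxes
for ar in aspect_ratios:
    sq_ar = math.sqrt(ar)
    wh_pairs.extend([[s_k * sq_ar, s_k / sq_ar],   # wide box
                     [s_k / sq_ar, s_k * sq_ar]])  # tall box

print(torch.as_tensor(wh_pairs))
# tensor([[0.2000, 0.2000],
#         [0.2608, 0.2608],
#         [0.2828, 0.1414],
#         [0.1414, 0.2828]])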

def num_anchors_per_location(self):
# Estimate num of anchors based on aspect ratios: 2 default boxes + 2 * ratios of feature map.
return [2 + 2 * len(r) for r in self.aspect_ratios]
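
As a sanity check, the SSD300-style aspect-ratio configuration from the paper reproduces the familiar 4/6 anchors-per-location counts (the configuration below is the paper's, shown here for illustration):

aspect_ratios = [[2], [2, 3], [2, 3], [2, 3], [2], [2]]
print([2 + 2 * len(r) for r in aspect_ratios])
# [4, 6, 6, 6, 4, 4]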

# Default Boxes calculation based on page 6 of SSD paper
def _grid_default_boxes(self, grid_sizes: List[List[int]], image_size: List[int],
dtype: torch.dtype = torch.float32) -> Tensor:
default_boxes = []
for k, f_k in enumerate(grid_sizes):
# Now add the default boxes for each width-height pair
if self.steps is not None:
x_f_k, y_f_k = [img_shape / self.steps[k] for img_shape in image_size]
else:
y_f_k, x_f_k = f_k

shifts_x = (torch.arange(0, f_k[1], dtype=dtype) + 0.5) / x_f_k
shifts_y = (torch.arange(0, f_k[0], dtype=dtype) + 0.5) / y_f_k
@zhiqwang (Contributor, Author) commented on May 9, 2021:

Actually the default_boxes are generated on the CPU and then moved to the CUDA device. I've tried the following to generate the default_boxes directly on the CUDA device, but it takes longer than the for-loop method:

shifts_x = (torch.arange(0, f_k[1], device=torch.device('cuda'), dtype=dtype) + 0.5) / x_f_k
shifts_y = (torch.arange(0, f_k[0], device=torch.device('cuda'), dtype=dtype) + 0.5) / y_f_k

@fmassa (Member) commented:

I've hit cases in the past where micro-benchmarks on exactly this part of the code were faster on the CPU, but caused significant slowdowns when training on multiple GPUs. Even if this variant is slower in single-GPU micro-benchmarks, it might still be faster on multiple GPUs.

A contributor replied:

@fmassa So you recommend passing the target device to this method and creating the tensors directly on it, right?

@fmassa (Member) replied:

We can leave it as is for now, but I would create a follow-up issue to benchmark both configurations on multiple GPUs to verify.

@zhiqwang (Contributor, Author) commented on May 11, 2021:

FYI, I've measured the total inference time over the COCO eval dataset for both default-box generation methods on different devices, and the timings are very similar.

Validated with (using one card):

CUDA_VISIBLE_DEVICES=0 python train.py --dataset coco --model ssd300_vgg16 \
    --batch-size 16 --pretrained --test-only
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)

shifts = torch.stack((shift_x, shift_y) * len(self._wh_pairs[k]), dim=-1).reshape(-1, 2)
# Clipping the default boxes while the boxes are encoded in format (cx, cy, w, h)
_wh_pair = self._wh_pairs[k].clamp(min=0, max=1) if self.clip else self._wh_pairs[k]
wh_pairs = _wh_pair.repeat((f_k[0] * f_k[1]), 1)

default_box = torch.cat((shifts, wh_pairs), dim=1)

default_boxes.append(default_box)

return torch.cat(default_boxes, dim=0)
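
To make the vectorized grid construction concrete, here is a self-contained sketch for a single hypothetical 2x2 feature map with 4 width-height pairs (shapes and values are illustrative):

import torch

f_k = (2, 2)     # hypothetical feature-map (height, width)
num_pairs = 4    # e.g. 2 + 2 * len(aspect_ratios[k])
dtype = torch.float32

y_f_k, x_f_k = f_k  # no custom steps: normalize by the grid size itself
shifts_x = (torch.arange(0, f_k[1], dtype=dtype) + 0.5) / x_f_k  # [0.25, 0.75]
shifts_y = (torch.arange(0, f_k[0], dtype=dtype) + 0.5) / y_f_k  # [0.25, 0.75]
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_x, shift_y = shift_x.reshape(-1), shift_y.reshape(-1)

# Repeat each (cx, cy) center once per width-height pair, matching the
# stack/reshape trick in _grid_default_boxes above.
shifts = torch.stack((shift_x, shift_y) * num_pairs, dim=-1).reshape(-1, 2)
print(shifts.shape)  # torch.Size([16, 2]) -- 2*2 locations, 4 pairs each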

def __repr__(self) -> str:
s = self.__class__.__name__ + '('
s += 'aspect_ratios={aspect_ratios}'
@@ -203,30 +236,12 @@ def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Ten
grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
image_size = image_list.tensors.shape[-2:]
dtype, device = feature_maps[0].dtype, feature_maps[0].device

# Default Boxes calculation based on page 6 of SSD paper
default_boxes: List[List[float]] = []
for k, f_k in enumerate(grid_sizes):
# Now add the default boxes for each width-height pair
for j in range(f_k[0]):
if self.steps is not None:
y_f_k = image_size[1] / self.steps[k]
else:
y_f_k = float(f_k[0])
cy = (j + 0.5) / y_f_k
for i in range(f_k[1]):
if self.steps is not None:
x_f_k = image_size[0] / self.steps[k]
else:
x_f_k = float(f_k[1])
cx = (i + 0.5) / x_f_k
default_boxes.extend([[cx, cy, w, h] for w, h in self._wh_pairs[k]])
default_boxes = self._grid_default_boxes(grid_sizes, image_size, dtype=dtype)
default_boxes = default_boxes.to(device)

dboxes = []
for _ in image_list.image_sizes:
dboxes_in_image = torch.tensor(default_boxes, dtype=dtype, device=device)
if self.clip:
dboxes_in_image.clamp_(min=0, max=1)
dboxes_in_image = default_boxes
dboxes_in_image = torch.cat([dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:],
dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:]], -1)
dboxes_in_image[:, 0::2] *= image_size[1]
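
The closing lines of the hunk convert the (cx, cy, w, h) boxes to corner format and rescale the x coordinates by the image width (the hunk is truncated here). A minimal sketch of that conversion on one hypothetical normalized box:

import torch

boxes = torch.tensor([[0.5, 0.5, 0.2, 0.4]])  # hypothetical (cx, cy, w, h), normalized
image_size = (300, 300)                       # hypothetical (height, width)

# (cx, cy, w, h) -> (x1, y1, x2, y2), still in [0, 1].
boxes_xyxy = torch.cat([boxes[:, :2] - 0.5 * boxes[:, 2:],
                        boxes[:, :2] + 0.5 * boxes[:, 2:]], dim=-1)

# Scale x coordinates by image width and y coordinates by image height.
boxes_xyxy[:, 0::2] *= image_size[1]
boxes_xyxy[:, 1::2] *= image_size[0]
print(boxes_xyxy)  # tensor([[120.,  90., 180., 210.]])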