Skip to content

Commit

Permalink
Cleanup AnchorGenerator and PostProcess (#203)
Browse files Browse the repository at this point in the history
* Cleanup Anchor configuration mechanism

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix the latest compatibility issues

* Fix missing outputs

* Fix docstrings

* Fix anchors in AnchorGenerator._generate_shifts

* Fix TestAnchorGenerator

* Fix test_anchor_generator

* Fix test_postprocessors

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Fixing YOLOTRTModule

* Fix pylint

* Minor fix

* Fix tensor.stride in AnchorGenerator._generate_shifts

* Cleanup codes

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
zhiqwang and pre-commit-ci[bot] authored Dec 20, 2021
1 parent b4fac50 commit 5dd25c3
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 171 deletions.
43 changes: 24 additions & 19 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,13 @@ def _get_anchor_grids(use_p6: bool):
]
return anchor_grids

def _compute_num_anchors(self, height, width, use_p6: bool):
def _compute_anchors(self, height, width, use_p6: bool):
strides = self._get_strides(use_p6)
num_anchors = 0
anchors_num = len(strides)
anchors_shape = []
for s in strides:
num_anchors += (height // s) * (width // s)
return num_anchors * 3
anchors_shape.append((height // s, width // s))
return anchors_num, anchors_shape

def _get_feature_shapes(self, height, width, width_multiple=0.5, use_p6=False):
in_channels = self._get_in_channels(width_multiple, use_p6)
Expand Down Expand Up @@ -238,12 +239,14 @@ def test_anchor_generator(self, width_multiple, use_p6, batch_size, height, widt
)
model = self._init_test_anchor_generator(use_p6)
anchors = model(feature_maps)
expected_num_anchors = self._compute_num_anchors(height, width, use_p6)
expected_anchors_num, expected_anchors_shape = self._compute_anchors(height, width, use_p6)

assert len(anchors) == 2
assert len(anchors[0]) == len(anchors[1]) == expected_anchors_num
for i in range(expected_anchors_num):
assert tuple(anchors[0][i].shape) == (1, 3, *(expected_anchors_shape[i]), 2)
assert tuple(anchors[1][i].shape) == (1, 3, *(expected_anchors_shape[i]), 2)

assert len(anchors) == 3
assert tuple(anchors[0].shape) == (expected_num_anchors, 2)
assert tuple(anchors[1].shape) == (expected_num_anchors, 1)
assert tuple(anchors[2].shape) == (expected_num_anchors, 2)
_check_jit_scriptable(model, (feature_maps,))

def _init_test_yolo_head(self, width_multiple=0.5, use_p6=False):
Expand All @@ -269,29 +272,31 @@ def test_yolo_head(self):
assert head_outputs[2].shape == target_head_outputs[2].shape
_check_jit_scriptable(model, (feature_maps,))

def _init_test_postprocessors(self):
def _init_test_postprocessors(self, strides):
score_thresh = 0.5
nms_thresh = 0.45
detections_per_img = 100
postprocessors = PostProcess(score_thresh, nms_thresh, detections_per_img)
postprocessors = PostProcess(strides, score_thresh, nms_thresh, detections_per_img)
return postprocessors

def test_postprocessors(self):
@pytest.mark.parametrize("use_p6", [False, True])
def test_postprocessors(self, use_p6):
N, H, W = 4, 416, 352
feature_maps = self._get_feature_maps(N, H, W)
head_outputs = self._get_head_outputs(N, H, W)
strides = self._get_strides(use_p6)
feature_maps = self._get_feature_maps(N, H, W, use_p6=use_p6)
head_outputs = self._get_head_outputs(N, H, W, use_p6=use_p6)

anchor_generator = self._init_test_anchor_generator()
anchors_tuple = anchor_generator(feature_maps)
model = self._init_test_postprocessors()
out = model(head_outputs, anchors_tuple)
anchor_generator = self._init_test_anchor_generator(use_p6=use_p6)
grids, shifts = anchor_generator(feature_maps)
model = self._init_test_postprocessors(strides)
out = model(head_outputs, grids, shifts)

assert len(out) == N
assert isinstance(out[0], dict)
assert isinstance(out[0]["boxes"], Tensor)
assert isinstance(out[0]["labels"], Tensor)
assert isinstance(out[0]["scores"], Tensor)
_check_jit_scriptable(model, (head_outputs, anchors_tuple))
_check_jit_scriptable(model, (head_outputs, grids, shifts))

def test_criterion(self, use_p6=False):
N, H, W = 4, 640, 640
Expand Down
18 changes: 8 additions & 10 deletions test/test_models_anchor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,13 @@ def test_anchor_generator(self):
model.eval()
anchors = model(features)

expected_anchor_output = torch.tensor([[-0.5, -0.5], [0.5, -0.5], [-0.5, 0.5], [0.5, 0.5]])
expected_wh_output = torch.tensor([[4.0], [4.0], [4.0], [4.0]])
expected_xy_output = torch.tensor([[6.0, 14.0], [6.0, 14.0], [6.0, 14.0], [6.0, 14.0]])
expected_grids = torch.tensor([[[[[0.0, 0.0], [1.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]]]])
expected_shifts = torch.tensor([[[[[6.0, 14.0], [6.0, 14.0]], [[6.0, 14.0], [6.0, 14.0]]]]])

assert len(anchors) == 3
assert tuple(anchors[0].shape) == (4, 2)
assert tuple(anchors[1].shape) == (4, 1)
assert tuple(anchors[2].shape) == (4, 2)
assert len(anchors) == 2
assert len(anchors[0]) == len(anchors[1]) == 1
assert tuple(anchors[0][0].shape) == (1, 1, 2, 2, 2)
assert tuple(anchors[1][0].shape) == (1, 1, 2, 2, 2)

torch.testing.assert_close(anchors[0], expected_anchor_output, rtol=0, atol=0)
torch.testing.assert_close(anchors[1], expected_wh_output, rtol=0, atol=0)
torch.testing.assert_close(anchors[2], expected_xy_output, rtol=0, atol=0)
torch.testing.assert_close(anchors[0][0], expected_grids)
torch.testing.assert_close(anchors[1][0], expected_shifts)
43 changes: 23 additions & 20 deletions yolort/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@

import torch
from torch import nn, Tensor
from torchvision.ops import box_convert, box_iou
from torchvision.ops import box_iou


def _evaluate_iou(target, pred):
"""
Evaluate intersection over union (IOU) for target from dataset and output prediction from model
Evaluate intersection over union (IOU) for target from dataset and
output prediction from model
"""
if pred["boxes"].shape[0] == 0:
# no box detected, 0 IOU
Expand All @@ -34,8 +35,7 @@ def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) ->

def encode_single(reference_boxes: Tensor, anchors: Tensor) -> Tensor:
"""
Encode a set of anchors with respect to some
reference boxes
Encode a set of anchors with respect to some reference boxes
Args:
reference_boxes (Tensor): reference boxes
Expand All @@ -52,23 +52,24 @@ def encode_single(reference_boxes: Tensor, anchors: Tensor) -> Tensor:

def decode_single(
rel_codes: Tensor,
anchors_tuple: Tuple[Tensor, Tensor, Tensor],
) -> Tensor:
grid: Tensor,
shift: Tensor,
stride: int,
) -> Tuple[Tensor, Tensor]:
"""
From a set of original boxes and encoded relative box offsets,
get the decoded boxes.
Arguments:
rel_codes (Tensor): encoded boxes
anchors_tupe (Tensor, Tensor, Tensor): reference boxes.
Args:
rel_codes (Tensor): Encoded boxes
grid (Tensor): Anchor grids
shift (Tensor): Anchor shifts
stride (int): Stride
"""
pred_xy = (rel_codes[..., 0:2] * 2.0 - 0.5 + grid) * stride
pred_wh = (rel_codes[..., 2:4] * 2.0) ** 2 * shift

pred_wh = (rel_codes[..., 0:2] * 2.0 + anchors_tuple[0]) * anchors_tuple[1] # wh
pred_xy = (rel_codes[..., 2:4] * 2) ** 2 * anchors_tuple[2] # xy
pred_boxes = torch.cat([pred_wh, pred_xy], dim=1)
pred_boxes = box_convert(pred_boxes, in_fmt="cxcywh", out_fmt="xyxy")

return pred_boxes
return pred_xy, pred_wh


def bbox_iou(box1: Tensor, box2: Tensor, x1y1x2y2: bool = True, eps: float = 1e-7):
Expand Down Expand Up @@ -99,8 +100,10 @@ def bbox_iou(box1: Tensor, box2: Tensor, x1y1x2y2: bool = True, eps: float = 1e-

iou = inter / union

cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width
ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
# convex (smallest enclosing box) width
cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)
# convex height
ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)
# Complete IoU https://arxiv.org/abs/1911.08287v1
c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
rho2 = (
Expand Down Expand Up @@ -149,7 +152,7 @@ def forward(self, pred, logit):

if self.reduction == "mean":
return loss.mean()
elif self.reduction == "sum":
if self.reduction == "sum":
return loss.sum()
else: # 'none'
return loss
# 'none'
return loss
104 changes: 38 additions & 66 deletions yolort/models/anchor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,89 +6,61 @@


class AnchorGenerator(nn.Module):
def __init__(
self,
strides: List[int],
anchor_grids: List[List[float]],
):
def __init__(self, strides: List[int], anchor_grids: List[List[float]]):

super().__init__()
assert len(strides) == len(anchor_grids)
self.num_anchors = len(anchor_grids[0]) // 2
self.strides = strides
self.anchor_grids = anchor_grids
self.num_layers = len(anchor_grids)
self.num_anchors = len(anchor_grids[0]) // 2

def set_wh_weights(
def _generate_grids(
self,
grid_sizes: List[List[int]],
dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cpu"),
) -> Tensor:

wh_weights = []

for size, stride in zip(grid_sizes, self.strides):
grid_height, grid_width = size
stride = torch.as_tensor([stride], dtype=dtype, device=device)
stride = stride.view(-1, 1)
stride = stride.repeat(1, grid_height * grid_width * self.num_anchors)
stride = stride.reshape(-1, 1)
wh_weights.append(stride)
) -> List[Tensor]:

return torch.cat(wh_weights)

def set_xy_weights(
self,
grid_sizes: List[List[int]],
dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cpu"),
) -> Tensor:
grids = []
for height, width in grid_sizes:
# For output anchor, compute [x_center, y_center, x_center, y_center]
widths = torch.arange(width, dtype=torch.int32, device=device).to(dtype=dtype)
heights = torch.arange(height, dtype=torch.int32, device=device).to(dtype=dtype)

xy_weights = []
shift_y, shift_x = torch.meshgrid(heights, widths)

for size, anchor_grid in zip(grid_sizes, self.anchor_grids):
grid_height, grid_width = size
anchor_grid = torch.as_tensor(anchor_grid, dtype=dtype, device=device)
anchor_grid = anchor_grid.view(-1, 2)
anchor_grid = anchor_grid.repeat(1, grid_height * grid_width)
anchor_grid = anchor_grid.reshape(-1, 2)
xy_weights.append(anchor_grid)
grid = torch.stack((shift_x, shift_y), 2).expand((1, self.num_anchors, height, width, 2))
grids.append(grid)

return torch.cat(xy_weights)
return grids

def grid_anchors(
def _generate_shifts(
self,
grid_sizes: List[List[int]],
dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cpu"),
) -> Tensor:

anchors = []

for size in grid_sizes:
grid_height, grid_width = size

# For output anchor, compute [x_center, y_center, x_center, y_center]
shifts_x = torch.arange(0, grid_width, dtype=torch.int32, device=device).to(dtype=dtype)
shifts_y = torch.arange(0, grid_height, dtype=torch.int32, device=device).to(dtype=dtype)

shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)

shifts = torch.stack((shift_x, shift_y), dim=2)
shifts = shifts.view(1, grid_height, grid_width, 2)
shifts = shifts.repeat(self.num_anchors, 1, 1, 1)
shifts = shifts - torch.tensor(0.5, dtype=shifts.dtype, device=device)
shifts = shifts.reshape(-1, 2)

anchors.append(shifts)

return torch.cat(anchors)

def forward(self, feature_maps: List[Tensor]) -> Tuple[Tensor, Tensor, Tensor]:
) -> List[Tensor]:

anchors = torch.tensor(self.anchor_grids, dtype=dtype, device=device)
strides = torch.tensor(self.strides, dtype=dtype, device=device)
anchors = anchors.view(self.num_layers, -1, 2) / strides.view(-1, 1, 1)

shifts = []
for i, (height, width) in enumerate(grid_sizes):
shift = (
(anchors[i].clone() * self.strides[i])
.view((1, self.num_anchors, 1, 1, 2))
.expand((1, self.num_anchors, height, width, 2))
.contiguous()
.to(dtype=dtype)
)
shifts.append(shift)
return shifts

def forward(self, feature_maps: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
dtype, device = feature_maps[0].dtype, feature_maps[0].device

wh_weights = self.set_wh_weights(grid_sizes, dtype, device)
xy_weights = self.set_xy_weights(grid_sizes, dtype, device)
anchors = self.grid_anchors(grid_sizes, dtype, device)

return anchors, wh_weights, xy_weights
grids = self._generate_grids(grid_sizes, dtype=dtype, device=device)
shifts = self._generate_shifts(grid_sizes, dtype=dtype, device=device)
return grids, shifts
Loading

0 comments on commit 5dd25c3

Please sign in to comment.