-
Notifications
You must be signed in to change notification settings - Fork 0
/
anchor_generator
114 lines (102 loc) · 4.07 KB
/
anchor_generator
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#mmdetection/mmdet/core/anchor/anchor_generator.py
import torch
class AnchorGenerator(object):
    """Generate axis-aligned anchor boxes for one feature-map level.

    A small set of "base" anchors is built once, around a single centre
    point, from ``base_size``, ``scales`` and ``ratios``.
    :meth:`grid_anchors` then translates that set to every cell of a
    feature map.

    How scales and ratios are combined depends on the number of scales:

    * at most one scale: every ratio is combined with every scale (full
      cross-product), giving ``len(ratios) * len(scales)`` base anchors;
    * more than one scale: scales and ratios are multiplied element-wise
      (``scales[i]`` goes with ``ratios[i]``), so the two sequences must be
      broadcast-compatible and no cross-product is formed.

    Args:
        base_size (int | float): Side length of the square reference box.
        scales (sequence of numbers): Multipliers applied to ``base_size``.
        ratios (sequence of numbers): Height/width ratios of the anchors.
        scale_major (bool): In cross-product mode, selects the ordering of
            the base anchors: ratios vary slowest when True, scales vary
            slowest when False.
        ctr (tuple | None): ``(x, y)`` centre of the base anchors. When
            None, the centre of the ``base_size`` box,
            ``0.5 * (base_size - 1)``, is used on both axes.

    Examples:
        >>> from mmdet.core import AnchorGenerator
        >>> self = AnchorGenerator(9, [1.], [1.])
        >>> all_anchors = self.grid_anchors((2, 2), device='cpu')
        >>> print(all_anchors)
        tensor([[ 0.,  0.,  8.,  8.],
                [16.,  0., 24.,  8.],
                [ 0., 16.,  8., 24.],
                [16., 16., 24., 24.]])
    """

    def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None):
        self.base_size = base_size
        # Supplying more than one scale switches gen_base_anchors() from the
        # scale x ratio cross-product to element-wise pairing.
        self.multi_scale = len(scales) > 1
        self.scales = torch.Tensor(scales)
        self.ratios = torch.Tensor(ratios)
        self.scale_major = scale_major
        self.ctr = ctr
        self.base_anchors = self.gen_base_anchors()

    @property
    def num_base_anchors(self):
        """Number of base anchors, i.e. anchors per feature-map cell."""
        return self.base_anchors.size(0)

    def gen_base_anchors(self):
        """Build the base anchors around the configured centre.

        Returns:
            torch.Tensor: ``(A, 4)`` tensor of rounded
            ``(x1, y1, x2, y2)`` boxes, where ``A`` is the number of
            scale/ratio combinations (see the class docstring).
        """
        w = self.base_size
        h = self.base_size
        if self.ctr is None:
            x_ctr = 0.5 * (w - 1)
            y_ctr = 0.5 * (h - 1)
        else:
            x_ctr, y_ctr = self.ctr

        # ratio = h / w, so heights scale by sqrt(ratio) and widths by its
        # reciprocal.
        h_ratios = torch.sqrt(self.ratios)
        w_ratios = 1 / h_ratios

        if self.multi_scale:
            # Paired mode: element-wise product, scales[i] with ratios[i].
            # Both orderings yield the anchors in the same order; only the
            # order of the floating-point multiplications differs.
            if self.scale_major:
                ws = (w * w_ratios * self.scales).view(-1)
                hs = (h * h_ratios * self.scales).view(-1)
            else:
                ws = (w * self.scales * w_ratios).view(-1)
                hs = (h * self.scales * h_ratios).view(-1)
        else:
            # Cross-product mode: broadcast ratios against scales and
            # flatten; scale_major decides which factor varies slowest.
            if self.scale_major:
                ws = (w * w_ratios[:, None] * self.scales[None, :]).view(-1)
                hs = (h * h_ratios[:, None] * self.scales[None, :]).view(-1)
            else:
                ws = (w * self.scales[:, None] * w_ratios[None, :]).view(-1)
                hs = (h * self.scales[:, None] * h_ratios[None, :]).view(-1)

        # A box of width ws spans (ws - 1) between its corner coordinates.
        half_w = 0.5 * (ws - 1)
        half_h = 0.5 * (hs - 1)
        base_anchors = torch.stack(
            [x_ctr - half_w, y_ctr - half_h, x_ctr + half_w, y_ctr + half_h],
            dim=-1).round()
        return base_anchors

    def _meshgrid(self, x, y, row_major=True):
        """Return the flattened coordinate grid of two 1-D tensors.

        Args:
            x (torch.Tensor): 1-D tensor; varies fastest in the output.
            y (torch.Tensor): 1-D tensor; varies slowest in the output.
            row_major (bool): Return ``(xx, yy)`` when True, ``(yy, xx)``
                otherwise.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Two 1-D tensors of length
            ``len(x) * len(y)``.
        """
        xx = x.repeat(len(y))
        yy = y.view(-1, 1).repeat(1, len(x)).view(-1)
        if row_major:
            return xx, yy
        return yy, xx

    def grid_anchors(self, featmap_size, stride=16, device='cuda'):
        """Place the base anchors on every cell of a feature map.

        Args:
            featmap_size (tuple): ``(feat_h, feat_w)`` of the feature map.
            stride (int): Offset between the anchors of neighbouring cells.
            device (str | torch.device): Device of the returned tensor.

        Returns:
            torch.Tensor: ``(feat_h * feat_w * A, 4)`` anchors. The first
            ``A`` rows belong to cell (0, 0), the next ``A`` to cell
            (0, 1), and so on along each row of the feature map.
        """
        base_anchors = self.base_anchors.to(device)

        feat_h, feat_w = featmap_size
        shift_x = torch.arange(0, feat_w, device=device) * stride
        shift_y = torch.arange(0, feat_h, device=device) * stride
        shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
        # The same (dx, dy) shift applies to both corners of a box.
        shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
        shifts = shifts.type_as(base_anchors)

        # Add A anchors (1, A, 4) to K shifts (K, 1, 4) to get the shifted
        # anchors (K, A, 4), then flatten to (K * A, 4).
        all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
        all_anchors = all_anchors.view(-1, 4)
        return all_anchors

    def valid_flags(self, featmap_size, valid_size, device='cuda'):
        """Flag the anchors whose cell lies inside the valid region.

        Args:
            featmap_size (tuple): ``(feat_h, feat_w)`` of the feature map.
            valid_size (tuple): ``(valid_h, valid_w)``; cells with row index
                below ``valid_h`` and column index below ``valid_w`` are
                valid. Must not exceed ``featmap_size``.
            device (str | torch.device): Device of the returned tensor.

        Returns:
            torch.Tensor: ``uint8`` tensor of length
            ``feat_h * feat_w * A``, in the same order as the output of
            :meth:`grid_anchors`; 1 marks a valid anchor.
        """
        feat_h, feat_w = featmap_size
        valid_h, valid_w = valid_size
        assert valid_h <= feat_h and valid_w <= feat_w

        valid_x = torch.zeros(feat_w, dtype=torch.uint8, device=device)
        valid_y = torch.zeros(feat_h, dtype=torch.uint8, device=device)
        valid_x[:valid_w] = 1
        valid_y[:valid_h] = 1
        valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
        valid = valid_xx & valid_yy
        # Repeat each cell's flag once per base anchor.
        valid = valid[:, None].expand(
            valid.size(0), self.num_base_anchors).contiguous().view(-1)
        return valid