Skip to content

Commit

Permalink
Replace type T with accumulator.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Feb 11, 2021
1 parent 51500c7 commit 4807308
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
11 changes: 11 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,17 @@ def test_autocast(self):
with torch.cuda.amp.autocast():
self.test_nms_cuda(dtype=dtype)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_nms_cuda_float16(self):
boxes = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551], [285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019]]).cuda()
scores = torch.tensor([0.6370, 0.7569, 0.3966]).cuda()

iou_thres = 0.2
keep32 = ops.nms(boxes, scores, iou_thres)
keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres)
self.assertTrue(torch.all(torch.eq(keep32, keep16)))


class DeformConvTester(OpTester, unittest.TestCase):
def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
Expand Down
13 changes: 7 additions & 6 deletions torchvision/csrc/ops/cuda/nms_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ __device__ inline bool devIoU(
T const* const a,
T const* const b,
const float threshold) {
T left = max(a[0], b[0]), right = min(a[2], b[2]);
T top = max(a[1], b[1]), bottom = min(a[3], b[3]);
T width = max(right - left, (T)0), height = max(bottom - top, (T)0);
T interS = width * height;
T Sa = (a[2] - a[0]) * (a[3] - a[1]);
T Sb = (b[2] - b[0]) * (b[3] - b[1]);
using acc_T = at::acc_type<T, /*is_cuda=*/true>;
acc_T left = max(a[0], b[0]), right = min(a[2], b[2]);
acc_T top = max(a[1], b[1]), bottom = min(a[3], b[3]);
acc_T width = max(right - left, (acc_T)0), height = max(bottom - top, (acc_T)0);
acc_T interS = width * height;
acc_T Sa = (a[2] - a[0]) * (a[3] - a[1]);
acc_T Sb = (b[2] - b[0]) * (b[3] - b[1]);
return (interS / (Sa + Sb - interS)) > threshold;
}

Expand Down

0 comments on commit 4807308

Please sign in to comment.