Skip to content

Commit

Permalink
Fix NMS and IoU overflows for fp16 (#3383)
Browse files Browse the repository at this point in the history
* Replace type T with accumulator.

* Upcast tensors of box ops to avoid overflow in multiplications.
  • Loading branch information
datumbox authored Feb 15, 2021
1 parent af97ec2 commit f04e9cb
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 39 deletions.
107 changes: 73 additions & 34 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,18 @@ def test_autocast(self):
with torch.cuda.amp.autocast():
self.test_nms_cuda(dtype=dtype)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_nms_cuda_float16(self):
boxes = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019]]).cuda()
scores = torch.tensor([0.6370, 0.7569, 0.3966]).cuda()

iou_thres = 0.2
keep32 = ops.nms(boxes, scores, iou_thres)
keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres)
self.assertTrue(torch.all(torch.eq(keep32, keep16)))


class DeformConvTester(OpTester, unittest.TestCase):
def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
Expand Down Expand Up @@ -829,48 +841,75 @@ def test_bbox_convert_jit(self):

class BoxAreaTester(unittest.TestCase):
def test_box_area(self):
# A bounding box of area 10000 and a degenerate case
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float)
expected = torch.tensor([10000, 0])
calc_area = ops.box_area(box_tensor)
assert calc_area.size() == torch.Size([2])
assert calc_area.dtype == box_tensor.dtype
assert torch.all(torch.eq(calc_area, expected)).item() is True
def area_check(box, expected, tolerance=1e-4):
out = ops.box_area(box)
assert out.size() == expected.size()
assert ((out - expected).abs().max() < tolerance).item()

# Check for int boxes
for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype)
expected = torch.tensor([10000, 0])
area_check(box_tensor, expected)

# Check for float32 and float64 boxes
for dtype in [torch.float32, torch.float64]:
box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype)
expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=torch.float64)
area_check(box_tensor, expected, tolerance=0.05)

# Check for float16 box
box_tensor = torch.tensor([[285.25, 185.625, 1194.0, 851.5],
[285.25, 188.75, 1192.0, 851.0],
[279.25, 198.0, 1189.0, 849.0]], dtype=torch.float16)
expected = torch.tensor([605113.875, 600495.1875, 592247.25])
area_check(box_tensor, expected)


class BoxIouTester(unittest.TestCase):
def test_iou(self):
# Boxes to test Iou
boxes1 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float)
boxes2 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float)

# Expected IoU matrix for these boxes
expected = torch.tensor([[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]])

out = ops.box_iou(boxes1, boxes2)

# Check if all elements of tensor are as expected.
assert out.size() == torch.Size([3, 3])
tolerance = 1e-4
assert ((out - expected).abs().max() < tolerance).item() is True
def iou_check(box, expected, tolerance=1e-4):
out = ops.box_iou(box, box)
assert out.size() == expected.size()
assert ((out - expected).abs().max() < tolerance).item()

# Check for int boxes
for dtype in [torch.int16, torch.int32, torch.int64]:
box = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=dtype)
expected = torch.tensor([[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]])
iou_check(box, expected)

# Check for float boxes
for dtype in [torch.float16, torch.float32, torch.float64]:
box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype)
expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]])
iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-4)


class GenBoxIouTester(unittest.TestCase):
def test_gen_iou(self):
# Test Generalized IoU
boxes1 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float)
boxes2 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float)

# Expected gIoU matrix for these boxes
expected = torch.tensor([[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611],
[-0.7778, -0.8611, 1.0]])

out = ops.generalized_box_iou(boxes1, boxes2)

# Check if all elements of tensor are as expected.
assert out.size() == torch.Size([3, 3])
tolerance = 1e-4
assert ((out - expected).abs().max() < tolerance).item() is True
def gen_iou_check(box, expected, tolerance=1e-4):
out = ops.generalized_box_iou(box, box)
assert out.size() == expected.size()
assert ((out - expected).abs().max() < tolerance).item()

# Check for int boxes
for dtype in [torch.int16, torch.int32, torch.int64]:
box = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=dtype)
expected = torch.tensor([[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]])
gen_iou_check(box, expected)

# Check for float boxes
for dtype in [torch.float16, torch.float32, torch.float64]:
box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype)
expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]])
gen_iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-3)


if __name__ == '__main__':
Expand Down
8 changes: 5 additions & 3 deletions torchvision/csrc/ops/cuda/nms_kernel.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
Expand All @@ -20,9 +21,10 @@ __device__ inline bool devIoU(
T left = max(a[0], b[0]), right = min(a[2], b[2]);
T top = max(a[1], b[1]), bottom = min(a[3], b[3]);
T width = max(right - left, (T)0), height = max(bottom - top, (T)0);
T interS = width * height;
T Sa = (a[2] - a[0]) * (a[3] - a[1]);
T Sb = (b[2] - b[0]) * (b[3] - b[1]);
using acc_T = at::acc_type<T, /*is_cuda=*/true>;
acc_T interS = (acc_T)width * height;
acc_T Sa = ((acc_T)a[2] - a[0]) * (a[3] - a[1]);
acc_T Sb = ((acc_T)b[2] - b[0]) * (b[3] - b[1]);
return (interS / (Sa + Sb - interS)) > threshold;
}

Expand Down
13 changes: 11 additions & 2 deletions torchvision/ops/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,14 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
return boxes


def _upcast(t: Tensor) -> Tensor:
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
if t.is_floating_point():
return t if t.dtype in (torch.float32, torch.float64) else t.float()
else:
return t if t.dtype in (torch.int32, torch.int64) else t.int()


def box_area(boxes: Tensor) -> Tensor:
"""
Computes the area of a set of bounding boxes, which are specified by its
Expand All @@ -182,6 +190,7 @@ def box_area(boxes: Tensor) -> Tensor:
Returns:
area (Tensor[N]): area for each box
"""
boxes = _upcast(boxes)
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])


Expand All @@ -194,7 +203,7 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]:
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]

wh = (rb - lt).clamp(min=0) # [N,M,2]
wh = _upcast(rb - lt).clamp(min=0) # [N,M,2]
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]

union = area1[:, None] + area2 - inter
Expand Down Expand Up @@ -247,7 +256,7 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

whi = (rbi - lti).clamp(min=0) # [N,M,2]
whi = _upcast(rbi - lti).clamp(min=0) # [N,M,2]
areai = whi[:, :, 0] * whi[:, :, 1]

return iou - (areai - union) / areai

0 comments on commit f04e9cb

Please sign in to comment.