Skip to content

Commit

Permalink
add float support to utils.draw_bounding_boxes() (#8328)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicolas Hug <[email protected]>
Co-authored-by: Nicolas Hug <[email protected]>
  • Loading branch information
3 people authored Apr 19, 2024
1 parent 0367c21 commit 96640af
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 11 deletions.
20 changes: 17 additions & 3 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,23 @@ def test_draw_boxes():
assert_equal(img, img_cp)


@pytest.mark.parametrize("fill", [True, False])
def test_draw_boxes_dtypes(fill):
img_uint8 = torch.full((3, 100, 100), 255, dtype=torch.uint8)
out_uint8 = utils.draw_bounding_boxes(img_uint8, boxes, fill=fill)

assert img_uint8 is not out_uint8
assert out_uint8.dtype == torch.uint8

img_float = to_dtype(img_uint8, torch.float, scale=True)
out_float = utils.draw_bounding_boxes(img_float, boxes, fill=fill)

assert img_float is not out_float
assert out_float.is_floating_point()

torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1)


@pytest.mark.parametrize("colors", [None, ["red", "blue", "#FF00FF", (1, 34, 122)], "red", "#FF00FF", (1, 34, 122)])
def test_draw_boxes_colors(colors):
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
Expand Down Expand Up @@ -152,7 +169,6 @@ def test_draw_boxes_grayscale():

def test_draw_invalid_boxes():
img_tp = ((1, 1, 1), (1, 2, 3))
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
img_correct = torch.zeros((3, 10, 10), dtype=torch.uint8)
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
Expand All @@ -162,8 +178,6 @@ def test_draw_invalid_boxes():

with pytest.raises(TypeError, match="Tensor expected"):
utils.draw_bounding_boxes(img_tp, boxes)
with pytest.raises(ValueError, match="Tensor uint8 expected"):
utils.draw_bounding_boxes(img_wrong1, boxes)
with pytest.raises(ValueError, match="Pass individual images, not batches"):
utils.draw_bounding_boxes(img_wrong2, boxes)
with pytest.raises(ValueError, match="Only grayscale and RGB images are supported"):
Expand Down
23 changes: 15 additions & 8 deletions torchvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,12 @@ def draw_bounding_boxes(
) -> torch.Tensor:

"""
Draws bounding boxes on given image.
The values of the input image should be uint8 between 0 and 255.
Draws bounding boxes on given RGB image.
The image values should be uint8 in [0, 255] or float in [0, 1].
If fill is True, Resulting Tensor should be saved as PNG image.
Args:
image (Tensor): Tensor of shape (C x H x W) and dtype uint8.
image (Tensor): Tensor of shape (C, H, W) and dtype uint8 or float.
boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
`0 <= ymin < ymax < H`.
Expand All @@ -188,13 +188,14 @@ def draw_bounding_boxes(
Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
"""
import torchvision.transforms.v2.functional as F # noqa

if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(draw_bounding_boxes)
if not isinstance(image, torch.Tensor):
raise TypeError(f"Tensor expected, got {type(image)}")
elif image.dtype != torch.uint8:
raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
elif not (image.dtype == torch.uint8 or image.is_floating_point()):
raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}")
elif image.dim() != 3:
raise ValueError("Pass individual images, not batches")
elif image.size(0) not in {1, 3}:
Expand Down Expand Up @@ -230,8 +231,11 @@ def draw_bounding_boxes(
if image.size(0) == 1:
image = torch.tile(image, (3, 1, 1))

ndarr = image.permute(1, 2, 0).cpu().numpy()
img_to_draw = Image.fromarray(ndarr)
original_dtype = image.dtype
if original_dtype.is_floating_point:
image = F.to_dtype(image, dtype=torch.uint8, scale=True)

img_to_draw = F.to_pil_image(image)
img_boxes = boxes.to(torch.int64).tolist()

if fill:
Expand All @@ -250,7 +254,10 @@ def draw_bounding_boxes(
margin = width + 1
draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font)

return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
out = F.pil_to_tensor(img_to_draw)
if original_dtype.is_floating_point:
out = F.to_dtype(out, dtype=original_dtype, scale=True)
return out


@torch.no_grad()
Expand Down

0 comments on commit 96640af

Please sign in to comment.