diff --git a/test/test_utils.py b/test/test_utils.py index cb6aa7cf6d1..020327d67fd 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -210,14 +210,11 @@ def test_draw_segmentation_masks(colors, alpha, device): num_masks, h, w = 2, 100, 100 dtype = torch.uint8 img = torch.randint(0, 256, size=(3, h, w), dtype=dtype, device=device) - masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool, device=device) + masks = torch.zeros((num_masks, h, w), dtype=torch.bool, device=device) + masks[0, 10:20, 10:20] = True + masks[1, 15:25, 15:25] = True - # For testing we enforce that there's no overlap between the masks. The - # current behaviour is that the last mask's color will take priority when - # masks overlap, but this makes testing slightly harder, so we don't really - # care overlap = masks[0] & masks[1] - masks[:, overlap] = False out = utils.draw_segmentation_masks(img, masks, colors=colors, alpha=alpha) assert out.dtype == dtype @@ -239,12 +236,15 @@ def test_draw_segmentation_masks(colors, alpha, device): color = torch.tensor(color, dtype=dtype, device=device) if alpha == 1: - assert (out[:, mask] == color[:, None]).all() + assert (out[:, mask & ~overlap] == color[:, None]).all() elif alpha == 0: - assert (out[:, mask] == img[:, mask]).all() + assert (out[:, mask & ~overlap] == img[:, mask & ~overlap]).all() - interpolated_color = (img[:, mask] * (1 - alpha) + color[:, None] * alpha).to(dtype) - torch.testing.assert_close(out[:, mask], interpolated_color, rtol=0.0, atol=1.0) + interpolated_color = (img[:, mask & ~overlap] * (1 - alpha) + color[:, None] * alpha).to(dtype) + torch.testing.assert_close(out[:, mask & ~overlap], interpolated_color, rtol=0.0, atol=1.0) + + interpolated_overlap = (img[:, overlap] * (1 - alpha)).to(dtype) + torch.testing.assert_close(out[:, overlap], interpolated_overlap, rtol=0.0, atol=1.0) def test_draw_segmentation_masks_dtypes(): diff --git a/torchvision/utils.py b/torchvision/utils.py index 630eada5cbc..530f938f7a7 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -299,6 +299,7 @@ def draw_segmentation_masks( raise ValueError("The image and the masks must have the same height and width") num_masks = masks.size()[0] + overlapping_masks = masks.sum(dim=0) > 1 if num_masks == 0: warnings.warn("masks doesn't contain any mask. No mask was drawn") @@ -315,6 +316,8 @@ def draw_segmentation_masks( for mask, color in zip(masks, colors): img_to_draw[:, mask] = color[:, None] + img_to_draw[:, overlapping_masks] = 0 + out = image * (1 - alpha) + img_to_draw * alpha # Note: at this point, out is a float tensor in [0, 1] or [0, 255] depending on original_dtype return out.to(original_dtype)