diff --git a/docs/requirements.txt b/docs/requirements.txt
index 44132ef3375..d2eb35aac8e 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,6 +1,6 @@
-sphinx==3.5.4
-sphinx-gallery>=0.9.0
-sphinx-copybutton>=0.3.1
 matplotlib
 numpy
+sphinx-copybutton>=0.3.1
+sphinx-gallery>=0.9.0
+sphinx==3.5.4
 -e git+git://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
diff --git a/docs/source/ops.rst b/docs/source/ops.rst
index ecef74dd8a6..5fd4b75e59d 100644
--- a/docs/source/ops.rst
+++ b/docs/source/ops.rst
@@ -9,19 +9,20 @@ torchvision.ops
 
 All operators have native support for TorchScript.
 
-.. autofunction:: nms
 .. autofunction:: batched_nms
-.. autofunction:: remove_small_boxes
-.. autofunction:: clip_boxes_to_image
-.. autofunction:: box_convert
 .. autofunction:: box_area
+.. autofunction:: box_convert
 .. autofunction:: box_iou
+.. autofunction:: clip_boxes_to_image
+.. autofunction:: deform_conv2d
 .. autofunction:: generalized_box_iou
-.. autofunction:: roi_align
+.. autofunction:: masks_to_boxes
+.. autofunction:: nms
 .. autofunction:: ps_roi_align
-.. autofunction:: roi_pool
 .. autofunction:: ps_roi_pool
-.. autofunction:: deform_conv2d
+.. autofunction:: remove_small_boxes
+.. autofunction:: roi_align
+.. autofunction:: roi_pool
 .. autofunction:: sigmoid_focal_loss
 .. autofunction:: stochastic_depth
 
diff --git a/gallery/plot_repurposing_annotations.py b/gallery/plot_repurposing_annotations.py
new file mode 100644
index 00000000000..2decefcc815
--- /dev/null
+++ b/gallery/plot_repurposing_annotations.py
@@ -0,0 +1,75 @@
+"""
+=======================
+Repurposing annotations
+=======================
+
+The following example illustrates the operations available in the torchvision.ops module for repurposing object
+localization annotations for different tasks (e.g. transforming masks used by instance and panoptic segmentation
+methods into bounding boxes used by object detection methods).
+"""
+import os.path
+
+import PIL.Image
+import matplotlib.patches
+import matplotlib.pyplot
+import numpy
+import torch
+from torchvision.ops import masks_to_boxes
+
+ASSETS_DIRECTORY = "../test/assets"
+
+matplotlib.pyplot.rcParams["savefig.bbox"] = "tight"
+
+####################################
+# Masks
+# -----
+# In tasks like instance and panoptic segmentation, masks are commonly defined, and this package defines them,
+# as a multi-dimensional array (e.g. a NumPy array or a PyTorch tensor) with the following shape:
+#
+#       (objects, height, width)
+#
+# where objects is the number of annotated objects in the image. Each (height, width) slice corresponds to exactly
+# one object. For example, if your input image has the dimensions 224 x 224 and four annotated objects, your masks
+# annotation has the following shape:
+#
+#       (4, 224, 224).
+#
+# A nice property of masks is that they can be easily repurposed for methods that solve a variety of object
+# localization tasks.
+#
+# Masks to bounding boxes
+# -----------------------
+# For example, the masks_to_boxes operation can be used to transform masks into bounding boxes that can be
+# used in methods like Faster R-CNN and YOLO.
+
+with PIL.Image.open(os.path.join(ASSETS_DIRECTORY, "masks.tiff")) as image:
+    masks = torch.zeros((image.n_frames, image.height, image.width), dtype=torch.int)
+
+    for index in range(image.n_frames):
+        image.seek(index)
+
+        frame = numpy.array(image)
+
+        masks[index] = torch.tensor(frame)
+
+bounding_boxes = masks_to_boxes(masks)
+
+figure = matplotlib.pyplot.figure()
+
+a = figure.add_subplot(121)
+b = figure.add_subplot(122)
+
+labeled_image = torch.sum(masks, 0)
+
+a.imshow(labeled_image)
+b.imshow(labeled_image)
+
+for bounding_box in bounding_boxes:
+    x0, y0, x1, y1 = bounding_box
+
+    rectangle = matplotlib.patches.Rectangle((x0, y0), x1 - x0, y1 - y0, linewidth=1, edgecolor="r", facecolor="none")
+
+    b.add_patch(rectangle)
+
+a.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
+b.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
diff --git a/test/assets/labeled_image.png b/test/assets/labeled_image.png
new file mode 100644
index 00000000000..9d163243773
Binary files /dev/null and b/test/assets/labeled_image.png differ
diff --git a/test/assets/masks.tiff b/test/assets/masks.tiff
new file mode 100644
index 00000000000..7a8efc6dd0e
Binary files /dev/null and b/test/assets/masks.tiff differ
diff --git a/test/test_masks_to_boxes.py b/test/test_masks_to_boxes.py
new file mode 100644
index 00000000000..7182ebcae9f
--- /dev/null
+++ b/test/test_masks_to_boxes.py
@@ -0,0 +1,34 @@
+import os.path
+
+import PIL.Image
+import numpy
+import torch
+
+from torchvision.ops import masks_to_boxes
+
+ASSETS_DIRECTORY = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
+
+
+def test_masks_to_boxes():
+    with PIL.Image.open(os.path.join(ASSETS_DIRECTORY, "masks.tiff")) as image:
+        masks = torch.zeros((image.n_frames, image.height, image.width), dtype=torch.int)
+
+        for index in range(image.n_frames):
+            image.seek(index)
+
+            frame = numpy.array(image)
+
+            masks[index] = torch.tensor(frame)
+
+    expected = torch.tensor(
+        [[127, 2, 165, 40],
+         [2, 50, 44, 92],
+         [56, 63, 98, 100],
+         [139, 68, 175, 104],
+         [160, 112, 198, 145],
+         [49, 138, 99, 182],
+         [108, 148, 152, 213]],
+        dtype=torch.int32
+    )
+
+    torch.testing.assert_close(masks_to_boxes(masks), expected)
diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py
index 606c27abcbe..33b35dc93b9 100644
--- a/torchvision/ops/__init__.py
+++ b/torchvision/ops/__init__.py
@@ -1,4 +1,5 @@
-from .boxes import nms, batched_nms, remove_small_boxes, clip_boxes_to_image, box_area, box_iou, generalized_box_iou
+from .boxes import nms, batched_nms, remove_small_boxes, clip_boxes_to_image, box_area, box_iou, generalized_box_iou, \
+    masks_to_boxes
 from .boxes import box_convert
 from .deform_conv import deform_conv2d, DeformConv2d
 from .roi_align import roi_align, RoIAlign
diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py
index c1f176f4da9..6dafcf1c190 100644
--- a/torchvision/ops/boxes.py
+++ b/torchvision/ops/boxes.py
@@ -297,3 +297,35 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
     areai = whi[:, :, 0] * whi[:, :, 1]
 
     return iou - (areai - union) / areai
+
+
+def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
+    """
+    Compute the bounding boxes around the provided masks.
+
+    Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with
+    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
+
+    Args:
+        masks (Tensor[N, H, W]): masks to transform where N is the number of
+            masks and (H, W) are the spatial dimensions.
+
+    Returns:
+        Tensor[N, 4]: bounding boxes
+    """
+    if masks.numel() == 0:
+        return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
+
+    n = masks.shape[0]
+
+    bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.int)
+
+    for index, mask in enumerate(masks):
+        y, x = torch.where(mask != 0)
+
+        bounding_boxes[index, 0] = torch.min(x)
+        bounding_boxes[index, 1] = torch.min(y)
+        bounding_boxes[index, 2] = torch.max(x)
+        bounding_boxes[index, 3] = torch.max(y)
+
+    return bounding_boxes
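
A minimal usage sketch of the new masks_to_boxes operator, for reference only; the toy masks below are invented for illustration and are not part of the patch:

    import torch

    from torchvision.ops import masks_to_boxes

    # Two 4x4 binary masks: one object fills the top-left 2x2 corner,
    # the other fills the bottom-right 2x2 corner.
    masks = torch.zeros((2, 4, 4), dtype=torch.int)
    masks[0, 0:2, 0:2] = 1
    masks[1, 2:4, 2:4] = 1

    # Each box is the (x1, y1, x2, y2) extent of the nonzero pixels of one mask.
    print(masks_to_boxes(masks))
    # tensor([[0, 0, 1, 1],
    #         [2, 2, 3, 3]], dtype=torch.int32)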