Skip to content

Commit

Permalink
masks_to_bounding_boxes op (#4290)
Browse files Browse the repository at this point in the history
* ops.masks_to_bounding_boxes

* test fixtures

* unit test

* ignore lint e201 and e202 for in-lined matrix

* ignore e121 and e241 linting rules for in-lined matrix

* draft gallery example text

* removed type annotations from pytest fixtures

* inlined fixture

* renamed masks_to_bounding_boxes to masks_to_boxes

* reformat inline array

* import cleanup

* moved masks_to_boxes into boxes module

* docstring cleanup

* updated docstring

* fix formatting issue

* gallery example

* use torch

* use torch

* use torch

* use torch

* updated docs and test

* cleanup

* updated import

* use torch

* Update gallery/plot_repurposing_annotations.py

Co-authored-by: Aditya Oke <[email protected]>

* Update gallery/plot_repurposing_annotations.py

Co-authored-by: Aditya Oke <[email protected]>

* Update gallery/plot_repurposing_annotations.py

Co-authored-by: Aditya Oke <[email protected]>

* Autodoc

* use torch instead of numpy in tests

* fix build_docs failure

* Closing quotes.

Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Aditya Oke <[email protected]>
  • Loading branch information
3 people authored Sep 21, 2021
1 parent 8a83cf2 commit f0422e7
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 11 deletions.
6 changes: 3 additions & 3 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
sphinx==3.5.4
sphinx-gallery>=0.9.0
sphinx-copybutton>=0.3.1
matplotlib
numpy
sphinx-copybutton>=0.3.1
sphinx-gallery>=0.9.0
sphinx==3.5.4
-e git+git://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
15 changes: 8 additions & 7 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,20 @@ torchvision.ops
All operators have native support for TorchScript.


.. autofunction:: nms
.. autofunction:: batched_nms
.. autofunction:: remove_small_boxes
.. autofunction:: clip_boxes_to_image
.. autofunction:: box_convert
.. autofunction:: box_area
.. autofunction:: box_convert
.. autofunction:: box_iou
.. autofunction:: clip_boxes_to_image
.. autofunction:: deform_conv2d
.. autofunction:: generalized_box_iou
.. autofunction:: roi_align
.. autofunction:: masks_to_boxes
.. autofunction:: nms
.. autofunction:: ps_roi_align
.. autofunction:: roi_pool
.. autofunction:: ps_roi_pool
.. autofunction:: deform_conv2d
.. autofunction:: remove_small_boxes
.. autofunction:: roi_align
.. autofunction:: roi_pool
.. autofunction:: sigmoid_focal_loss
.. autofunction:: stochastic_depth

Expand Down
75 changes: 75 additions & 0 deletions gallery/plot_repurposing_annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
=======================
Repurposing annotations
=======================
The following example illustrates the operations available in the torchvision.ops module for repurposing object
localization annotations for different tasks (e.g. transforming masks used by instance and panoptic segmentation
methods into bounding boxes used by object detection methods).
"""
import os.path

import PIL.Image
import matplotlib.patches
import matplotlib.pyplot
import numpy
import torch
from torchvision.ops import masks_to_boxes

ASSETS_DIRECTORY = "../test/assets"

matplotlib.pyplot.rcParams["savefig.bbox"] = "tight"

####################################
# Masks
# -----
# In tasks like instance and panoptic segmentation, masks are commonly defined, and are defined by this package,
# as a multi-dimensional array (e.g. a NumPy array or a PyTorch tensor) with the following shape:
#
# (objects, height, width)
#
# Where objects is the number of annotated objects in the image. Each (height, width) slice corresponds to exactly
# one object. For example, if your input image has the dimensions 224 x 224 and has four annotated objects, your
# masks annotation has the following shape:
#
# (4, 224, 224).
#
# A nice property of masks is that they can be easily repurposed to be used in methods to solve a variety of object
# localization tasks.
#
# Masks to bounding boxes
# ----------------------------------------
# For example, the ``masks_to_boxes`` operation can be used to transform masks into bounding boxes that can be
# used in methods like Faster R-CNN and YOLO.

# Load every frame of the multi-frame TIFF asset into one (frames, height, width) integer tensor.
with PIL.Image.open(os.path.join(ASSETS_DIRECTORY, "masks.tiff")) as tiff:
    frame_count = tiff.n_frames
    masks = torch.zeros((frame_count, tiff.height, tiff.width), dtype=torch.int)

    for frame_index in range(frame_count):
        tiff.seek(frame_index)
        masks[frame_index] = torch.tensor(numpy.array(tiff))

bounding_boxes = masks_to_boxes(masks)

# Two side-by-side panels: the left one shows the masks only, the right one adds the boxes.
figure = matplotlib.pyplot.figure()
a = figure.add_subplot(121)
b = figure.add_subplot(122)

# Summing over the first dimension collapses all masks into a single 2D image.
labeled_image = torch.sum(masks, 0)

for axes in (a, b):
    axes.imshow(labeled_image)

# Outline each box on the right-hand panel.
for x0, y0, x1, y1 in bounding_boxes:
    b.add_patch(
        matplotlib.patches.Rectangle(
            (x0, y0), x1 - x0, y1 - y0, linewidth=1, edgecolor="r", facecolor="none"
        )
    )

# Hide ticks and tick labels on both panels.
for axes in (a, b):
    axes.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
Binary file added test/assets/labeled_image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/assets/masks.tiff
Binary file not shown.
34 changes: 34 additions & 0 deletions test/test_masks_to_boxes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import os.path

import PIL.Image
import numpy
import torch

from torchvision.ops import masks_to_boxes

ASSETS_DIRECTORY = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")


def test_masks_to_boxes():
    """Check ``masks_to_boxes`` against the known boxes of the ``masks.tiff`` asset."""
    asset_path = os.path.join(ASSETS_DIRECTORY, "masks.tiff")

    # Stack every frame of the multi-frame TIFF into one (frames, height, width) tensor.
    with PIL.Image.open(asset_path) as tiff:
        frame_count = tiff.n_frames
        masks = torch.zeros((frame_count, tiff.height, tiff.width), dtype=torch.int)

        for frame_index in range(frame_count):
            tiff.seek(frame_index)
            masks[frame_index] = torch.tensor(numpy.array(tiff))

    # One expected box per frame of the asset.
    expected = torch.tensor(
        [
            [127, 2, 165, 40],
            [2, 50, 44, 92],
            [56, 63, 98, 100],
            [139, 68, 175, 104],
            [160, 112, 198, 145],
            [49, 138, 99, 182],
            [108, 148, 152, 213],
        ],
        dtype=torch.int32,
    )

    actual = masks_to_boxes(masks)

    torch.testing.assert_close(actual, expected)
3 changes: 2 additions & 1 deletion torchvision/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .boxes import nms, batched_nms, remove_small_boxes, clip_boxes_to_image, box_area, box_iou, generalized_box_iou
from .boxes import nms, batched_nms, remove_small_boxes, clip_boxes_to_image, box_area, box_iou, generalized_box_iou, \
masks_to_boxes
from .boxes import box_convert
from .deform_conv import deform_conv2d, DeformConv2d
from .roi_align import roi_align, RoIAlign
Expand Down
32 changes: 32 additions & 0 deletions torchvision/ops/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,35 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
areai = whi[:, :, 0] * whi[:, :, 1]

return iou - (areai - union) / areai


def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
    """
    Compute the bounding boxes around the provided masks.

    Boxes are returned in ``(x1, y1, x2, y2)`` format, where ``(x1, y1)`` and ``(x2, y2)`` are the
    smallest and largest coordinates of the non-zero elements of each mask, so
    ``0 <= x1 <= x2`` and ``0 <= y1 <= y2``.

    Args:
        masks (Tensor[N, H, W]): masks to transform where N is the number of
            masks and (H, W) are the spatial dimensions.

    Returns:
        Tensor[N, 4]: bounding boxes, as a ``torch.int`` tensor on the same device as ``masks``.
        An empty input yields an empty ``[0, 4]`` tensor with that same dtype and device.

    Note:
        Every mask is expected to contain at least one non-zero element; ``torch.min`` and
        ``torch.max`` raise on the empty coordinate tensors produced by an all-zero mask.
    """
    if masks.numel() == 0:
        # Keep dtype and device consistent with the non-empty path below.
        return torch.zeros((0, 4), device=masks.device, dtype=torch.int)

    n = masks.shape[0]

    bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.int)

    for index, mask in enumerate(masks):
        # torch.where returns (row, column) indices, i.e. (y, x) coordinates.
        y, x = torch.where(mask != 0)

        bounding_boxes[index, 0] = torch.min(x)
        bounding_boxes[index, 1] = torch.min(y)
        bounding_boxes[index, 2] = torch.max(x)
        bounding_boxes[index, 3] = torch.max(y)

    return bounding_boxes

0 comments on commit f0422e7

Please sign in to comment.