masks_to_bounding_boxes op #4290

Merged: 40 commits, merged Sep 21, 2021

Commits
cf51379
ops.masks_to_bounding_boxes
0x00b1 Aug 17, 2021
c67e035
test fixtures
0x00b1 Aug 18, 2021
3830dd1
unit test
0x00b1 Aug 18, 2021
926d444
Merge branch 'master' into issues/3960
0x00b1 Aug 18, 2021
f777416
ignore lint e201 and e202 for in-lined matrix
0x00b1 Aug 18, 2021
cd46aa7
ignore e121 and e241 linting rules for in-lined matrix
0x00b1 Aug 18, 2021
712131e
draft gallery example text
0x00b1 Aug 18, 2021
b6f5c42
removed type annotations from pytest fixtures
0x00b1 Aug 31, 2021
b555c68
inlined fixture
0x00b1 Aug 31, 2021
fc26f3a
renamed masks_to_bounding_boxes to masks_to_boxes
0x00b1 Aug 31, 2021
c4d3045
reformat inline array
0x00b1 Aug 31, 2021
4589951
import cleanup
0x00b1 Aug 31, 2021
6b19d67
moved masks_to_boxes into boxes module
0x00b1 Sep 1, 2021
c6c89ec
docstring cleanup
0x00b1 Sep 1, 2021
16a99a9
updated docstring
0x00b1 Sep 15, 2021
7115320
fix formatting issue
0x00b1 Sep 15, 2021
f4796d2
Merge branch 'main' into issues/3960
datumbox Sep 15, 2021
a070133
Merge branch 'master' of https://github.com/pytorch/vision into issue…
0x00b1 Sep 15, 2021
0131db3
Merge branch 'issues/3960' of https://github.com/0x00b1/vision into i…
0x00b1 Sep 15, 2021
0a23bcf
gallery example
0x00b1 Sep 17, 2021
db8fb7b
use torch
0x00b1 Sep 17, 2021
f7a2c1e
use torch
0x00b1 Sep 17, 2021
c7dfcdf
use torch
0x00b1 Sep 17, 2021
5e6198a
use torch
0x00b1 Sep 17, 2021
7c78271
updated docs and test
0x00b1 Sep 17, 2021
b9055c2
cleanup
0x00b1 Sep 17, 2021
6c630c5
Merge branch 'main' into issues/3960
0x00b1 Sep 17, 2021
540c6a1
updated import
0x00b1 Sep 17, 2021
8e4fc2f
Merge branch 'main' into issues/3960
0x00b1 Sep 17, 2021
4c78297
use torch
0x00b1 Sep 20, 2021
140e429
Update gallery/plot_repurposing_annotations.py
0x00b1 Sep 20, 2021
8f2cd4a
Update gallery/plot_repurposing_annotations.py
0x00b1 Sep 20, 2021
7252723
Update gallery/plot_repurposing_annotations.py
0x00b1 Sep 20, 2021
26f68af
Merge branch 'main' into issues/3960
0x00b1 Sep 20, 2021
2c2d5dd
Autodoc
0x00b1 Sep 21, 2021
3a91957
use torch instead of numpy in tests
0x00b1 Sep 21, 2021
e24805c
fix build_docs failure
0x00b1 Sep 21, 2021
65404e9
Merge branch 'main' into issues/3960
0x00b1 Sep 21, 2021
6c89be7
Closing quotes.
datumbox Sep 21, 2021
b2a907c
Merge branch 'main' into issues/3960
datumbox Sep 21, 2021
6 changes: 3 additions & 3 deletions docs/requirements.txt
@@ -1,6 +1,6 @@
-sphinx==3.5.4
-sphinx-gallery>=0.9.0
-sphinx-copybutton>=0.3.1
 matplotlib
 numpy
+sphinx-copybutton>=0.3.1
+sphinx-gallery>=0.9.0
+sphinx==3.5.4
 -e git+git://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
15 changes: 8 additions & 7 deletions docs/source/ops.rst
@@ -9,19 +9,20 @@ torchvision.ops
All operators have native support for TorchScript.


-.. autofunction:: nms
 .. autofunction:: batched_nms
-.. autofunction:: remove_small_boxes
-.. autofunction:: clip_boxes_to_image
-.. autofunction:: box_convert
 .. autofunction:: box_area
+.. autofunction:: box_convert
 .. autofunction:: box_iou
+.. autofunction:: clip_boxes_to_image
+.. autofunction:: deform_conv2d
 .. autofunction:: generalized_box_iou
-.. autofunction:: roi_align
+.. autofunction:: masks_to_boxes
+.. autofunction:: nms
 .. autofunction:: ps_roi_align
-.. autofunction:: roi_pool
 .. autofunction:: ps_roi_pool
-.. autofunction:: deform_conv2d
+.. autofunction:: remove_small_boxes
+.. autofunction:: roi_align
+.. autofunction:: roi_pool
 .. autofunction:: sigmoid_focal_loss
 .. autofunction:: stochastic_depth

75 changes: 75 additions & 0 deletions gallery/plot_repurposing_annotations.py
@@ -0,0 +1,75 @@
"""
=======================
Repurposing annotations
=======================

The following example illustrates the operations available in the torchvision.ops module for repurposing object
localization annotations for different tasks (e.g. transforming masks used by instance and panoptic segmentation
methods into bounding boxes used by object detection methods).
"""
import os.path

import PIL.Image
import matplotlib.patches
import matplotlib.pyplot
import numpy
import torch
from torchvision.ops import masks_to_boxes

ASSETS_DIRECTORY = "../test/assets"

matplotlib.pyplot.rcParams["savefig.bbox"] = "tight"

####################################
# Masks
# -----
# In tasks like instance and panoptic segmentation, masks are commonly defined, and are defined by this package,
# as a multi-dimensional array (e.g. a NumPy array or a PyTorch tensor) with the following shape:
#
# (objects, height, width)
#
# Where objects is the number of annotated objects in the image. Each (height, width) slice corresponds to exactly
# one object. For example, if your input image has dimensions 224 x 224 and four annotated objects, your masks
# annotation has the following shape:
#
# (4, 224, 224).
#
# A nice property of masks is that they can be easily repurposed to be used in methods to solve a variety of object
# localization tasks.
#
# Masks to bounding boxes
# ----------------------------------------
# For example, the masks_to_boxes operation can be used to transform masks into bounding boxes that can be
# used in methods like Faster R-CNN and YOLO.

with PIL.Image.open(os.path.join(ASSETS_DIRECTORY, "masks.tiff")) as image:
    masks = torch.zeros((image.n_frames, image.height, image.width), dtype=torch.int)

    for index in range(image.n_frames):
        image.seek(index)

        frame = numpy.array(image)

        masks[index] = torch.tensor(frame)

bounding_boxes = masks_to_boxes(masks)

figure = matplotlib.pyplot.figure()

a = figure.add_subplot(121)
b = figure.add_subplot(122)

labeled_image = torch.sum(masks, 0)

a.imshow(labeled_image)
b.imshow(labeled_image)

for bounding_box in bounding_boxes:
    x0, y0, x1, y1 = bounding_box

    rectangle = matplotlib.patches.Rectangle((x0, y0), x1 - x0, y1 - y0, linewidth=1, edgecolor="r", facecolor="none")

    b.add_patch(rectangle)

a.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
b.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
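
For readers who want to try the new op without the masks.tiff asset, here is a minimal, self-contained sketch using synthetic masks. The two rectangular masks are made up for illustration and are not part of this PR; the box values shown in the trailing comment follow from the min/max definition of masks_to_boxes in the diff below.

import torch
from torchvision.ops import masks_to_boxes

# Two 32 x 32 binary masks, each containing a single rectangular object.
toy_masks = torch.zeros((2, 32, 32), dtype=torch.int)
toy_masks[0, 4:10, 6:12] = 1   # rows 4-9, columns 6-11
toy_masks[1, 20:28, 3:15] = 1  # rows 20-27, columns 3-14

print(masks_to_boxes(toy_masks))
# tensor([[ 6,  4, 11,  9],
#         [ 3, 20, 14, 27]], dtype=torch.int32)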
Binary file added test/assets/labeled_image.png
Binary file added test/assets/masks.tiff
34 changes: 34 additions & 0 deletions test/test_masks_to_boxes.py
@@ -0,0 +1,34 @@
import os.path

import PIL.Image
import numpy
import torch

from torchvision.ops import masks_to_boxes

ASSETS_DIRECTORY = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")


def test_masks_to_boxes():
    with PIL.Image.open(os.path.join(ASSETS_DIRECTORY, "masks.tiff")) as image:
        masks = torch.zeros((image.n_frames, image.height, image.width), dtype=torch.int)

        for index in range(image.n_frames):
            image.seek(index)

            frame = numpy.array(image)

            masks[index] = torch.tensor(frame)

    expected = torch.tensor(
        [[127, 2, 165, 40],
         [2, 50, 44, 92],
         [56, 63, 98, 100],
         [139, 68, 175, 104],
         [160, 112, 198, 145],
         [49, 138, 99, 182],
         [108, 148, 152, 213]],
        dtype=torch.int32
    )

    torch.testing.assert_close(masks_to_boxes(masks), expected)
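
The inline review further down notes that the remaining test improvements can land in a follow-up PR. One possibility is an asset-free test that builds the masks in memory; the sketch below is illustrative only (the function name, rectangles, and expected boxes are invented for this example, not part of this PR).

import torch

from torchvision.ops import masks_to_boxes


def test_masks_to_boxes_synthetic():
    # Two disjoint rectangular masks built directly in memory.
    masks = torch.zeros((2, 16, 16), dtype=torch.int)
    masks[0, 1:5, 2:8] = 1    # rows 1-4, columns 2-7
    masks[1, 9:14, 4:11] = 1  # rows 9-13, columns 4-10

    expected = torch.tensor([[2, 1, 7, 4],
                             [4, 9, 10, 13]], dtype=torch.int32)

    torch.testing.assert_close(masks_to_boxes(masks), expected)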
3 changes: 2 additions & 1 deletion torchvision/ops/__init__.py
@@ -1,4 +1,5 @@
-from .boxes import nms, batched_nms, remove_small_boxes, clip_boxes_to_image, box_area, box_iou, generalized_box_iou
+from .boxes import nms, batched_nms, remove_small_boxes, clip_boxes_to_image, box_area, box_iou, generalized_box_iou, \
+    masks_to_boxes
 from .boxes import box_convert
 from .deform_conv import deform_conv2d, DeformConv2d
 from .roi_align import roi_align, RoIAlign
32 changes: 32 additions & 0 deletions torchvision/ops/boxes.py
@@ -297,3 +297,35 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
    areai = whi[:, :, 0] * whi[:, :, 1]

    return iou - (areai - union) / areai


def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
    """
    Compute the bounding boxes around the provided masks.

    Returns an [N, 4] tensor containing bounding boxes in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        masks (Tensor[N, H, W]): masks to transform, where N is the number of
            masks and (H, W) are the spatial dimensions.

    Returns:
        Tensor[N, 4]: bounding boxes
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4))

    n = masks.shape[0]

    bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.int)
[Inline review, oke-aditya (Contributor), Sep 23, 2021]
My initial thought was that dtype should be torch.float, since all other ops use a float dtype.
cc @datumbox @NicolasHug

[Inline review reply, Contributor]
Agreed; also, the zeros above needs to have a device: torch.zeros((0, 4), device=masks.device).
Could you please send a PR that fixes these 2 issues? The rest of the doc/test improvements discussed here can happen on a separate PR.


    for index, mask in enumerate(masks):
        y, x = torch.where(masks[index] != 0)

        bounding_boxes[index, 0] = torch.min(x)
        bounding_boxes[index, 1] = torch.min(y)
        bounding_boxes[index, 2] = torch.max(x)
        bounding_boxes[index, 3] = torch.max(y)

    return bounding_boxes
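
For reference, below is a sketch of how the follow-up suggested in the inline review above might look: a float dtype for the boxes and a device-aware empty return. This is an illustrative sketch only, not the code merged in this PR; it also uses the mask loop variable directly as an incidental cleanup.

import torch


def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
    # Empty input: return an empty float box tensor on the same device as the masks.
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device, dtype=torch.float)

    n = masks.shape[0]

    bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.float)

    for index, mask in enumerate(masks):
        # Coordinates of all non-zero pixels belonging to this mask.
        y, x = torch.where(mask != 0)

        bounding_boxes[index, 0] = torch.min(x)
        bounding_boxes[index, 1] = torch.min(y)
        bounding_boxes[index, 2] = torch.max(x)
        bounding_boxes[index, 3] = torch.max(y)

    return bounding_boxes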