-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
masks_to_bounding_boxes
op
#4290
Changes from 34 commits
cf51379
c67e035
3830dd1
926d444
f777416
cd46aa7
712131e
b6f5c42
b555c68
fc26f3a
c4d3045
4589951
6b19d67
c6c89ec
16a99a9
7115320
f4796d2
a070133
0131db3
0a23bcf
db8fb7b
f7a2c1e
c7dfcdf
5e6198a
7c78271
b9055c2
6c630c5
540c6a1
8e4fc2f
4c78297
140e429
8f2cd4a
7252723
26f68af
2c2d5dd
3a91957
e24805c
65404e9
6c89be7
b2a907c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
sphinx==3.5.4 | ||
sphinx-gallery>=0.9.0 | ||
sphinx-copybutton>=0.3.1 | ||
matplotlib | ||
numpy | ||
sphinx-copybutton>=0.3.1 | ||
sphinx-gallery>=0.9.0 | ||
sphinx==3.5.4 | ||
-e git+git://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
""" | ||
======================= | ||
Repurposing annotations | ||
======================= | ||
|
||
The following example illustrates the operations available in :ref:`the torchvision.ops module <ops>` for repurposing | ||
object localization annotations for different tasks (e.g. transforming masks used by instance and panoptic | ||
segmentation methods into bounding boxes used by object detection methods). | ||
""" | ||
import os.path | ||
|
||
import PIL.Image | ||
import matplotlib.patches | ||
import matplotlib.pyplot | ||
import numpy | ||
import torch | ||
from torchvision.ops import masks_to_boxes | ||
|
||
ASSETS_DIRECTORY = "../test/assets" | ||
0x00b1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
matplotlib.pyplot.rcParams["savefig.bbox"] = "tight" | ||
|
||
#################################### | ||
# Masks | ||
# ----- | ||
# In tasks like instance and panoptic segmentation, masks are commonly defined, and are defined by this package, | ||
# as a multi-dimensional array (e.g. a NumPy array or a PyTorch tensor) with the following shape: | ||
# | ||
# (objects, height, width) | ||
# | ||
# Where objects is the number of annotated objects in the image. Each (height, width) object corresponds to exactly | ||
# one object. For example, if your input image has the dimensions 224 x 224 and has four annotated objects the shape | ||
# of your masks annotation has the following shape: | ||
# | ||
# (4, 224, 224). | ||
# | ||
# A nice property of masks is that they can be easily repurposed to be used in methods to solve a variety of object | ||
# localization tasks. | ||
# | ||
# Masks to bounding boxes | ||
# ---------------------------------------- | ||
# For example, the masks to bounding_boxes operation can be used to transform masks into bounding boxes that can be | ||
# used in methods like Faster RCNN and YOLO. | ||
|
||
with PIL.Image.open(os.path.join(ASSETS_DIRECTORY, "masks.tiff")) as image: | ||
masks = torch.zeros((image.n_frames, image.height, image.width), dtype=torch.int) | ||
|
||
for index in range(image.n_frames): | ||
image.seek(index) | ||
|
||
frame = numpy.array(image) | ||
|
||
masks[index] = torch.tensor(frame) | ||
|
||
bounding_boxes = masks_to_boxes(masks) | ||
|
||
figure = matplotlib.pyplot.figure() | ||
|
||
a = figure.add_subplot(121) | ||
b = figure.add_subplot(122) | ||
|
||
labeled_image = torch.sum(masks, 0) | ||
|
||
a.imshow(labeled_image) | ||
b.imshow(labeled_image) | ||
|
||
for bounding_box in bounding_boxes: | ||
x0, y0, x1, y1 = bounding_box | ||
|
||
rectangle = matplotlib.patches.Rectangle((x0, y0), x1 - x0, y1 - y0, linewidth=1, edgecolor="r", facecolor="none") | ||
|
||
b.add_patch(rectangle) | ||
|
||
a.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) | ||
b.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import os.path | ||
|
||
import PIL.Image | ||
import numpy | ||
import torch | ||
|
||
from torchvision.ops import masks_to_boxes | ||
|
||
ASSETS_DIRECTORY = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") | ||
|
||
|
||
def test_masks_to_boxes(): | ||
with PIL.Image.open(os.path.join(ASSETS_DIRECTORY, "masks.tiff")) as image: | ||
frames = numpy.zeros((image.n_frames, image.height, image.width), int) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @0x00b1 Could make the same changes you did on the galleries also here? Aka replace numpy with torch. |
||
|
||
for index in range(image.n_frames): | ||
image.seek(index) | ||
|
||
frames[index] = numpy.array(image) | ||
|
||
masks = torch.tensor(frames) | ||
|
||
expected = torch.tensor( | ||
[[127, 2, 165, 40], | ||
[2, 50, 44, 92], | ||
[56, 63, 98, 100], | ||
[139, 68, 175, 104], | ||
[160, 112, 198, 145], | ||
[49, 138, 99, 182], | ||
[108, 148, 152, 213]], | ||
dtype=torch.int32 | ||
) | ||
|
||
torch.testing.assert_close(masks_to_boxes(masks), expected) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -297,3 +297,35 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: | |
areai = whi[:, :, 0] * whi[:, :, 1] | ||
|
||
return iou - (areai - union) / areai | ||
|
||
|
||
def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor: | ||
""" | ||
Compute the bounding boxes around the provided masks | ||
|
||
Returns a [N, 4] tensor. Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with | ||
``0 <= x1 < x2`` and ``0 <= y1 < y2. | ||
datumbox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Args: | ||
masks (Tensor[N, H, W]): masks to transform where N is the number of | ||
masks and (H, W) are the spatial dimensions. | ||
|
||
Returns: | ||
Tensor[N, 4]: bounding boxes | ||
""" | ||
if masks.numel() == 0: | ||
return torch.zeros((0, 4)) | ||
|
||
n = masks.shape[0] | ||
|
||
bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.int) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My initial thought was dtype should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed, also the above zeros needs to have a device: Could you please send a PR that fixes these 2 issues? The rest of the doc/test improvements discussed here can happen on a separate PR. |
||
|
||
for index, mask in enumerate(masks): | ||
y, x = torch.where(masks[index] != 0) | ||
|
||
bounding_boxes[index, 0] = torch.min(x) | ||
bounding_boxes[index, 1] = torch.min(y) | ||
bounding_boxes[index, 2] = torch.max(x) | ||
bounding_boxes[index, 3] = torch.max(y) | ||
|
||
return bounding_boxes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After some debugging I found out the reason for
build_docs
CI failure. The problem istorchvision.ops
does not have a nice index on right side (basically a html link to #ops like transforms has). This causes CI failure.We need to remove the ref, and it will work fine. This is slightly hacky fix, but works fine.
I tried running it locally. I could build the gallery example. It looks nice.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! I appreciate the debugging.