Skip to content

Commit

Permalink
Rewrite test and fix masks_to_boxes implementation (#4469)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicolas Hug <[email protected]>
  • Loading branch information
oke-aditya and NicolasHug authored Sep 24, 2021
1 parent 021df7a commit cdb6fba
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 40 deletions.
34 changes: 0 additions & 34 deletions test/test_masks_to_boxes.py

This file was deleted.

34 changes: 34 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import pytest

import numpy as np
import os

from PIL import Image
import torch
from functools import lru_cache
from torch import Tensor
Expand Down Expand Up @@ -1000,6 +1002,38 @@ def gen_iou_check(box, expected, tolerance=1e-4):
gen_iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-3)


class TestMasksToBoxes:
def test_masks_box(self):
def masks_box_check(masks, expected, tolerance=1e-4):
out = ops.masks_to_boxes(masks)
assert out.dtype == torch.float
torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)

# Check for int type boxes.
def _get_image():
assets_directory = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
mask_path = os.path.join(assets_directory, "masks.tiff")
image = Image.open(mask_path)
return image

def _create_masks(image, masks):
for index in range(image.n_frames):
image.seek(index)
frame = np.array(image)
masks[index] = torch.tensor(frame)

return masks

expected = torch.tensor([[127, 2, 165, 40], [2, 50, 44, 92], [56, 63, 98, 100], [139, 68, 175, 104],
[160, 112, 198, 145], [49, 138, 99, 182], [108, 148, 152, 213]], dtype=torch.float)

image = _get_image()
for dtype in [torch.float16, torch.float32, torch.float64]:
masks = torch.zeros((image.n_frames, image.height, image.width), dtype=dtype)
masks = _create_masks(image, masks)
masks_box_check(masks, expected)


class TestStochasticDepth:
@pytest.mark.parametrize('p', [0.2, 0.5, 0.8])
@pytest.mark.parametrize('mode', ["batch", "row"])
Expand Down
12 changes: 6 additions & 6 deletions torchvision/ops/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,24 +301,24 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:

def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
"""
Compute the bounding boxes around the provided masks
Compute the bounding boxes around the provided masks.
Returns a [N, 4] tensor. Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
Args:
masks (Tensor[N, H, W]): masks to transform where N is the number of
masks and (H, W) are the spatial dimensions.
masks (Tensor[N, H, W]): masks to transform where N is the number of masks
and (H, W) are the spatial dimensions.
Returns:
Tensor[N, 4]: bounding boxes
"""
if masks.numel() == 0:
return torch.zeros((0, 4))
return torch.zeros((0, 4), device=masks.device, dtype=torch.float)

n = masks.shape[0]

bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.int)
bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.float)

for index, mask in enumerate(masks):
y, x = torch.where(masks[index] != 0)
Expand Down

0 comments on commit cdb6fba

Please sign in to comment.