Skip to content

Commit

Permalink
add test for DetectionsMask import/export
Browse files Browse the repository at this point in the history
  • Loading branch information
Kemal Eren committed Jun 18, 2024
1 parent be5ac20 commit e392b08
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 2 deletions.
1 change: 1 addition & 0 deletions fiftyone/__public__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
Classifications,
Detection,
Detections,
DetectionsMask,
Polyline,
Polylines,
Keypoint,
Expand Down
4 changes: 2 additions & 2 deletions fiftyone/core/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,7 @@ def _to_image(mask, extension):
if mask.dtype not in (np.uint8, np.uint16, np.uint32, np.uint64):
raise ValueError(
f"object detection mask dtype is '{mask.dtype}',"
f" but only uint dtypes are supported."
f" but only numpy uint dtypes are supported."
)

if mask.ndim != 2:
Expand All @@ -793,7 +793,7 @@ def _to_image(mask, extension):
maxval = mask.max()
if extension == ".png" and maxval >= 2**16:
raise ValueError(
f"max value of {maxval} exceeds upper limit of of uint16 for PNG file"
f"max value of {maxval} exceeds upper limit of np.uint16 for PNG file"
)
dtype = DetectionsMask._get_uint_dtype(maxval)
return mask.astype(dtype)
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def get_version():
"starlette>=0.24.0",
"strawberry-graphql==0.138.1",
"tabulate",
"tifffile",
"xmltodict",
"universal-analytics-python3>=1.0.1,<2",
# internal packages
Expand Down
35 changes: 35 additions & 0 deletions tests/unittests/label_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
|
"""
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory

from bson import Binary, ObjectId
import numpy as np
Expand Down Expand Up @@ -601,6 +603,39 @@ def test_transform_mask(self):
rgb_to_rgb = focl._transform_mask(int_to_rgb, targets_map)
nptest.assert_array_equal(rgb_to_rgb, np.zeros((3, 3, 3), dtype=int))

@drop_datasets
def test_detections_mask_io(self):
for dtype in (np.uint8, np.uint16, np.uint32, np.uint64):
for extension in ("png", "tif"):
for test_max in (False, True):
mask = np.arange(9, dtype=dtype).reshape((3, 3))
png_info = np.iinfo(np.uint16)
info = np.iinfo(dtype)
if test_max:
mask[-1, -1] = info.max

export_should_fail = (
test_max
and extension == "png"
and mask.max() > png_info.max
)

dm = fo.DetectionsMask(mask=mask)

with TemporaryDirectory() as tmp_dir:
outpath = Path(tmp_dir) / f"mask.{extension}"
if export_should_fail:
self.assertRaises(
ValueError,
dm.export_mask,
outpath,
update=True,
)
else:
dm.export_mask(outpath, update=True)
imported_mask = dm.get_mask()
nptest.assert_array_equal(mask, imported_mask)


class LabelUtilsTests(unittest.TestCase):
@drop_datasets
Expand Down

0 comments on commit e392b08

Please sign in to comment.