From e392b089fa3546c1b3572c88721684d3ddb85e58 Mon Sep 17 00:00:00 2001 From: Kemal Eren Date: Tue, 18 Jun 2024 15:59:16 -0400 Subject: [PATCH] add test for DetectionsMask import/export --- fiftyone/__public__.py | 1 + fiftyone/core/labels.py | 4 ++-- setup.py | 1 + tests/unittests/label_tests.py | 35 ++++++++++++++++++++++++++++++++++ 4 files changed, 39 insertions(+), 2 deletions(-) diff --git a/fiftyone/__public__.py b/fiftyone/__public__.py index e932c6c9c5..9811a15599 100644 --- a/fiftyone/__public__.py +++ b/fiftyone/__public__.py @@ -93,6 +93,7 @@ Classifications, Detection, Detections, + DetectionsMask, Polyline, Polylines, Keypoint, diff --git a/fiftyone/core/labels.py b/fiftyone/core/labels.py index 0e92c17c23..74fab73d5b 100644 --- a/fiftyone/core/labels.py +++ b/fiftyone/core/labels.py @@ -783,7 +783,7 @@ def _to_image(mask, extension): if mask.dtype not in (np.uint8, np.uint16, np.uint32, np.uint64): raise ValueError( f"object detection mask dtype is '{mask.dtype}'," - f" but only uint dtypes are supported." + f" but only numpy uint dtypes are supported." ) if mask.ndim != 2: @@ -793,7 +793,7 @@ def _to_image(mask, extension): maxval = mask.max() if extension == ".png" and maxval >= 2**16: raise ValueError( - f"max value of {maxval} exceeds upper limit of of uint16 for PNG file" + f"max value of {maxval} exceeds upper limit of np.uint16 for PNG file" ) dtype = DetectionsMask._get_uint_dtype(maxval) return mask.astype(dtype) diff --git a/setup.py b/setup.py index a7a031d9ab..afdfbd75b4 100644 --- a/setup.py +++ b/setup.py @@ -73,6 +73,7 @@ def get_version(): "starlette>=0.24.0", "strawberry-graphql==0.138.1", "tabulate", + "tifffile", "xmltodict", "universal-analytics-python3>=1.0.1,<2", # internal packages diff --git a/tests/unittests/label_tests.py b/tests/unittests/label_tests.py index 5832e69f7e..9bc47eb883 100644 --- a/tests/unittests/label_tests.py +++ b/tests/unittests/label_tests.py @@ -6,6 +6,8 @@ | """ import unittest +from pathlib import Path +from tempfile import TemporaryDirectory from bson import Binary, ObjectId import numpy as np @@ -601,6 +603,39 @@ def test_transform_mask(self): rgb_to_rgb = focl._transform_mask(int_to_rgb, targets_map) nptest.assert_array_equal(rgb_to_rgb, np.zeros((3, 3, 3), dtype=int)) + @drop_datasets + def test_detections_mask_io(self): + for dtype in (np.uint8, np.uint16, np.uint32, np.uint64): + for extension in ("png", "tif"): + for test_max in (False, True): + mask = np.arange(9, dtype=dtype).reshape((3, 3)) + png_info = np.iinfo(np.uint16) + info = np.iinfo(dtype) + if test_max: + mask[-1, -1] = info.max + + export_should_fail = ( + test_max + and extension == "png" + and mask.max() > png_info.max + ) + + dm = fo.DetectionsMask(mask=mask) + + with TemporaryDirectory() as tmp_dir: + outpath = Path(tmp_dir) / f"mask.{extension}" + if export_should_fail: + self.assertRaises( + ValueError, + dm.export_mask, + outpath, + update=True, + ) + else: + dm.export_mask(outpath, update=True) + imported_mask = dm.get_mask() + nptest.assert_array_equal(mask, imported_mask) + class LabelUtilsTests(unittest.TestCase): @drop_datasets