Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change image default dtype from float32 to uint8 #1175

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Enhancements
- Enhance Datumaro data format stream importer performance
(<https://github.com/openvinotoolkit/datumaro/pull/1153>)
- Change image default dtype from float32 to uint8
(<https://github.com/openvinotoolkit/datumaro/pull/1175>)

### Bug fixes
- Fix errata in the voc document. Color values in the labelmap.txt should be separated by commas, not colons.
Expand Down
14 changes: 6 additions & 8 deletions src/datumaro/components/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def __init__(

@property
def data(self) -> Optional[np.ndarray]:
"""Image data in BGRA HWC [0; 255] (float) format"""
"""Image data in BGRA HWC [0; 255] (uint8) format"""

if not self.has_data:
return None
Expand Down Expand Up @@ -375,12 +375,12 @@ def __init__(

@property
def data(self) -> Optional[np.ndarray]:
"""Image data in BGRA HWC [0; 255] (float) format"""
"""Image data in BGRA HWC [0; 255] (uint8) format"""

data = super().data

if isinstance(data, np.ndarray):
data = data.astype(np.float32)
if isinstance(data, np.ndarray) and data.dtype != np.uint8:
data = np.clip(data, 0.0, 255.0).astype(np.uint8)
if self._size is None and data is not None:
if not 2 <= data.ndim <= 3:
raise MediaShapeError("An image should have 2 (gray) or 3 (bgra) dims.")
Expand Down Expand Up @@ -420,14 +420,12 @@ def _guess_ext(cls, data: bytes) -> Optional[str]:

@property
def data(self) -> Optional[np.ndarray]:
"""Image data in BGRA HWC [0; 255] (float) format"""
"""Image data in BGRA HWC [0; 255] (uint8) format"""

data = super().data

if isinstance(data, bytes):
data = decode_image(data)
if isinstance(data, np.ndarray):
data = data.astype(np.float32)
data = decode_image(data, dtype=np.uint8)
if self._size is None and data is not None:
if not 2 <= data.ndim <= 3:
raise MediaShapeError("An image should have 2 (gray) or 3 (bgra) dims.")
Expand Down
3 changes: 0 additions & 3 deletions src/datumaro/plugins/framework_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,6 @@ def __init__(
def __getitem__(self, idx):
image, label = self._gen_item(idx)

if image.dtype == np.uint8 or image.max() > 1:
image = image.astype(np.float32) / 255

if len(image.shape) == 2:
image = np.expand_dims(image, axis=-1)

Expand Down
7 changes: 4 additions & 3 deletions tests/unit/operations/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@

@pytest.fixture
def fxt_image_dataset_expected_mean_std():
np.random.seed(3003)
expected_mean = [100, 50, 150]
expected_std = [20, 50, 10]
expected_std = [2, 1, 3]

return expected_mean, expected_std

Expand Down Expand Up @@ -90,9 +91,9 @@ def test_image_stats(
actual_std = actual["subsets"]["default"]["image std"][::-1]

for em, am in zip(expected_mean, actual_mean):
assert am == pytest.approx(em, 1e-2)
assert am == pytest.approx(em, 5e-1)
for estd, astd in zip(expected_std, actual_std):
assert astd == pytest.approx(estd, 1e-2)
assert astd == pytest.approx(estd, 1e-1)

@mark_requirement(Requirements.DATUM_BUG_873)
def test_invalid_media_type(
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_framework_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,8 @@ def test_can_convert_torch_framework(
label = np.sum(masks, axis=0, dtype=np.uint8)

if fxt_convert_kwargs.get("transform", None):
assert np.array_equal(image, dm_torch_item[0].reshape(5, 5, 3).numpy())
actual = dm_torch_item[0].permute(1, 2, 0).mul(255.0).to(torch.uint8).numpy()
assert np.array_equal(image, actual)
else:
assert np.array_equal(image, dm_torch_item[0])

Expand Down
7 changes: 4 additions & 3 deletions tests/unit/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@
class TestOperations(TestCase):
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_mean_std(self):
np.random.seed(3000)
expected_mean = [100, 50, 150]
expected_std = [20, 50, 10]
expected_std = [2, 1, 3]

dataset = Dataset.from_iterable(
[
Expand All @@ -62,9 +63,9 @@ def test_mean_std(self):
actual_mean, actual_std = mean_std(dataset)

for em, am in zip(expected_mean, actual_mean):
self.assertAlmostEqual(em, am, places=0)
assert np.allclose(em, am, atol=0.6)
for estd, astd in zip(expected_std, actual_std):
self.assertAlmostEqual(estd, astd, places=0)
assert np.allclose(estd, astd, atol=0.1)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_stats(self):
Expand Down
Loading