Skip to content

Commit

Permalink
Add dtype argument when calling media.data (#1546)
Browse files Browse the repository at this point in the history
<!-- Contributing guide:
https://github.com/openvinotoolkit/datumaro/blob/develop/CONTRIBUTING.md
-->

### Summary

<!--
Resolves #111 and #222.
Depends on #1000 (for series of dependent commits).

This PR introduces this capability to make the project better in this
and that.

- Added this feature
- Removed that feature
- Fixed the problem #1234
-->

### How to test
<!-- Describe the testing procedure for reviewers, if changes are
not fully covered by unit tests or manual testing can be complicated.
-->

### Checklist
<!-- Put an 'x' in all the boxes that apply -->
- [ ] I have added unit tests to cover my changes.​
- [ ] I have added integration tests to cover my changes.​
- [ ] I have added the description of my changes into
[CHANGELOG](https://github.com/openvinotoolkit/datumaro/blob/develop/CHANGELOG.md).​
- [ ] I have updated the
[documentation](https://github.com/openvinotoolkit/datumaro/tree/develop/docs)
accordingly

### License

- [ ] I submit _my code changes_ under the same [MIT
License](https://github.com/openvinotoolkit/datumaro/blob/develop/LICENSE)
that covers the project.
  Feel free to contact the maintainers if that's a concern.
- [ ] I have updated the license header for each file (see an example
below).

```python
# Copyright (C) 2024 Intel Corporation
#
# SPDX-License-Identifier: MIT
```
  • Loading branch information
wonjuleee authored Jun 28, 2024
1 parent d00d0cf commit 6a92276
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 11 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(<https://github.com/openvinotoolkit/datumaro/pull/1492>)
- Pass Keyword Argument to TabularDataBase
(<https://github.com/openvinotoolkit/datumaro/pull/1522>)
- Enable dtype argument when calling media.data
(<https://github.com/openvinotoolkit/datumaro/pull/1546>)

### Bug fixes
- Preserve end_frame information of a video when it is zero.
Expand Down
29 changes: 26 additions & 3 deletions src/datumaro/components/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import shutil
from copy import deepcopy
from enum import IntEnum
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -39,6 +40,7 @@
copyto_image,
decode_image,
lazy_image,
load_image,
save_image,
)

Expand Down Expand Up @@ -224,6 +226,7 @@ def __init__(
f"{self.__class__.__name__}.from_numpy(), {self.__class__.__name__}.from_bytes())."
)
super().__init__(*args, **kwargs)
self._dtype = np.uint8

if ext is not None:
if not ext.startswith("."):
Expand Down Expand Up @@ -322,6 +325,8 @@ def data(self) -> Optional[np.ndarray]:
if not self.has_data:
return None

if self.__data._dtype != self._dtype:
self.__data._loader = partial(load_image, dtype=self._dtype)
data = self.__data()

if self._size is None and data is not None:
Expand Down Expand Up @@ -368,6 +373,11 @@ def set_crypter(self, crypter: Crypter):
if isinstance(self.__data, lazy_image):
self.__data._crypter = crypter

def get_data_as_dtype(self, dtype: Optional[np.dtype] = np.uint8) -> Optional[np.ndarray]:
"""Get image data with a specific data type"""
self._dtype = dtype
return self.data


class ImageFromData(FromDataMixin, Image):
def save(
Expand Down Expand Up @@ -400,8 +410,8 @@ def data(self) -> Optional[np.ndarray]:

data = super().data

if isinstance(data, np.ndarray) and data.dtype != np.uint8:
data = np.clip(data, 0.0, 255.0).astype(np.uint8)
if isinstance(data, np.ndarray) and data.dtype != self._dtype:
data = np.clip(data, 0.0, 255.0).astype(self._dtype)
if self._size is None and data is not None:
if not 2 <= data.ndim <= 3:
raise MediaShapeError("An image should have 2 (gray) or 3 (bgra) dims.")
Expand All @@ -413,6 +423,11 @@ def has_size(self) -> bool:
"""Indicates that size info is cached and won't require image loading"""
return self._size is not None or isinstance(self._data, np.ndarray)

def get_data_as_dtype(self, dtype: Optional[np.dtype] = np.uint8) -> Optional[np.ndarray]:
"""Get image data with a specific data type"""
self._dtype = dtype
return self.data


class ImageFromBytes(ImageFromData):
_FORMAT_MAGICS = (
Expand Down Expand Up @@ -446,13 +461,21 @@ def data(self) -> Optional[np.ndarray]:
data = super().data

if isinstance(data, bytes):
data = decode_image(data, dtype=np.uint8)
data = decode_image(data, dtype=self._dtype)
if self._size is None and data is not None:
if not 2 <= data.ndim <= 3:
raise MediaShapeError("An image should have 2 (gray) or 3 (bgra) dims.")
self._size = tuple(map(int, data.shape[:2]))
return data

def get_data_as_dtype(self, dtype: Optional[np.dtype] = np.uint8) -> Optional[np.ndarray]:
"""Get image data with a specific data type"""

if dtype != np.uint8:
raise ValueError("ImageFromBytes only support `dtype=np.uint8`.")
self._dtype = dtype
return self.data


class VideoFrame(ImageFromNumpy):
_type = MediaType.VIDEO_FRAME
Expand Down
28 changes: 20 additions & 8 deletions src/datumaro/util/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ class ImageColorChannel(Enum):
COLOR_BGR = 1
COLOR_RGB = 2

def decode_by_cv2(self, image_bytes: bytes) -> np.ndarray:
def decode_by_cv2(self, image_bytes: bytes, dtype: DTypeLike = np.uint8) -> np.ndarray:
"""Convert image color channel for OpenCV image (np.ndarray)."""
image_buffer = np.frombuffer(image_bytes, dtype=np.uint8)
image_buffer = np.frombuffer(image_bytes, dtype=dtype)

if self == ImageColorChannel.UNCHANGED:
return cv2.imdecode(image_buffer, cv2.IMREAD_UNCHANGED)
Expand Down Expand Up @@ -283,15 +283,26 @@ def encode_image(image: np.ndarray, ext: str, dtype: DTypeLike = np.uint8, **kwa
raise NotImplementedError()


def decode_image(image_bytes: bytes, dtype: DTypeLike = np.uint8) -> np.ndarray:
def decode_image(image_bytes: bytes, dtype: np.dtype = np.uint8) -> np.ndarray:
ctx_color_scale = IMAGE_COLOR_CHANNEL.get()

if IMAGE_BACKEND.get() == ImageBackend.cv2:
image = ctx_color_scale.decode_by_cv2(image_bytes)
elif IMAGE_BACKEND.get() == ImageBackend.PIL:
image = ctx_color_scale.decode_by_pil(image_bytes)
if np.issubdtype(dtype, np.floating):
# PIL doesn't support floating point image loading
# CV doesn't support floating point image with color channel setting (IMREAD_COLOR)
with decode_image_context(
image_backend=ImageBackend.cv2, image_color_channel=ImageColorChannel.UNCHANGED
):
image = ctx_color_scale.decode_by_cv2(image_bytes, dtype=dtype)
image = image[..., ::-1]
if ctx_color_scale == ImageColorChannel.COLOR_BGR:
image = image[..., ::-1]
else:
raise NotImplementedError()
if IMAGE_BACKEND.get() == ImageBackend.cv2:
image = ctx_color_scale.decode_by_cv2(image_bytes)
elif IMAGE_BACKEND.get() == ImageBackend.PIL:
image = ctx_color_scale.decode_by_pil(image_bytes)
else:
raise NotImplementedError()

image = image.astype(dtype)

Expand Down Expand Up @@ -376,6 +387,7 @@ def __init__(
assert isinstance(cache, (ImageCache, bool))
self._cache = cache
self._crypter = crypter
self._dtype = dtype

def __call__(self) -> np.ndarray:
image = None
Expand Down
2 changes: 2 additions & 0 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ pytest-stress
pytest-html
coverage
pytest-csv

tifffile
20 changes: 20 additions & 0 deletions tests/unit/test_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,26 @@ def test_ext_detection_failure(self):
image = Image.from_bytes(data=image_bytes)
self.assertEqual(image.ext, None)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_floating_image_from_numpy(self):
image_float = np.random.rand(32, 32, 3).astype(np.float16) * 255.0
media = Image.from_numpy(image_float)
data = media.get_data_as_dtype(dtype=np.float16)
self.assertTrue(np.all(image_float == data))

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_floating_image_from_file(self):
import tifffile

with TestDir() as test_dir:
image_float = np.random.rand(32, 32, 3).astype(np.float32) * 255.0
image_path = osp.join(test_dir, "floating_image.tiff")
tifffile.imwrite(image_path, image_float)

media = Image.from_file(image_path)
data = media.get_data_as_dtype(dtype=np.float32)
self.assertTrue(np.all(image_float == data))


class RoIImageTest(TestCase):
def _test_ctors(self, img_ctor, args_list, test_dir, is_bytes=False):
Expand Down

0 comments on commit 6a92276

Please sign in to comment.