Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

POC: add tifffile serializer #425

Merged
merged 12 commits into from
Nov 28, 2024
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ filelock
numpy
boto3
requests
tifffile
23 changes: 22 additions & 1 deletion src/litdata/streaming/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

import numpy as np
import tifffile
import torch
from lightning_utilities.core.imports import RequirementCache

Expand Down Expand Up @@ -387,13 +388,33 @@ def can_serialize(self, data: float) -> bool:
return isinstance(data, float)


class TIFFSerializer(Serializer):
tchaton marked this conversation as resolved.
Show resolved Hide resolved
"""Serializer for TIFF files using tifffile."""

def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]:
    """Return the raw bytes of the file located at ``item``.

    Args:
        item: Path to the file to read. It must be a ``str`` that points to an existing regular file.

    Returns:
        A ``(data, None)`` tuple where ``data`` is the unmodified file content.

    Raises:
        ValueError: If ``item`` is not a string or does not point to an existing file.

    """
    points_to_file = isinstance(item, str) and os.path.isfile(item)
    if not points_to_file:
        raise ValueError(f"The item to serialize must be a valid file path. Received: {item}")

    # The file is stored verbatim; decoding is deferred to deserialization.
    with open(item, "rb") as tiff_file:
        payload = tiff_file.read()
    return payload, None

def deserialize(self, data: bytes) -> Any:
    """Decode TIFF-encoded bytes with ``tifffile.imread``.

    Args:
        data: The raw content of a TIFF file.

    Returns:
        The value returned by ``tifffile.imread`` for the in-memory file (a NumPy array).

    """
    # Wrap the bytes in a file-like object so tifffile can read them without touching disk.
    in_memory_file = io.BytesIO(data)
    return tifffile.imread(in_memory_file)

def can_serialize(self, item: Any) -> bool:
    """Tell whether ``item`` is a path to an existing ``.tif``/``.tiff`` file.

    Args:
        item: The candidate object to serialize.

    Returns:
        ``True`` only when ``item`` is a ``str``, points to an existing regular file, and its
        name ends with ``.tif`` or ``.tiff`` (compared case-insensitively).

    """
    if not isinstance(item, str):
        return False
    if not os.path.isfile(item):
        return False
    return item.lower().endswith((".tif", ".tiff"))


_SERIALIZERS = OrderedDict(
**{
"str": StringSerializer(),
"int": IntegerSerializer(),
"float": FloatSerializer(),
"video": VideoSerializer(),
"tif": FileSerializer(),
"tifffile": TIFFSerializer(),
"file": FileSerializer(),
"pil": PILSerializer(),
"jpeg": JPEGSerializer(),
Expand Down
38 changes: 37 additions & 1 deletion tests/streaming/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
import os
import random
import sys
import tempfile
from unittest import mock

import numpy as np
import pytest
import tifffile
import torch
from lightning_utilities.core.imports import RequirementCache
from litdata.streaming.serializers import (
Expand All @@ -34,6 +36,7 @@
NumpySerializer,
PILSerializer,
TensorSerializer,
TIFFSerializer,
VideoSerializer,
_get_serializers,
)
Expand All @@ -46,6 +49,7 @@ def seed_everything(random_seed):


_PIL_AVAILABLE = RequirementCache("PIL")
_TIFFFILE_AVAILABLE = RequirementCache("tifffile")


def test_serializers():
Expand All @@ -55,7 +59,7 @@ def test_serializers():
"int",
"float",
"video",
"tif",
"tifffile",
"file",
"pil",
"jpeg",
Expand Down Expand Up @@ -265,3 +269,35 @@ def test_deserialize_empty_no_header_tensor():
serializer.setup(name)
new_t = serializer.deserialize(data)
assert torch.equal(t, new_t)


@pytest.mark.skipif(not _TIFFFILE_AVAILABLE, reason="Requires: ['tifffile']")
def test_tiff_serializer():
    """Round-trip a synthetic multispectral TIFF through ``TIFFSerializer``.

    The image is written to disk with ``tifffile``, serialized from its path, deserialized back,
    and compared with the original array for type, shape, dtype and content.
    """
    serializer = TIFFSerializer()

    # Create a synthetic multispectral image
    height, width, bands = 28, 28, 12
    np_data = np.random.randint(0, 65535, size=(height, width, bands), dtype=np.uint16)

    # A temporary directory is removed on exit even when an assertion below fails, and the target
    # path is not held open by another handle while tifffile writes to it.
    with tempfile.TemporaryDirectory() as tmp_dir:
        file_path = os.path.join(tmp_dir, "image.tif")
        tifffile.imwrite(file_path, np_data)

        # Test can_serialize
        assert serializer.can_serialize(file_path)

        # Serialize
        data, _ = serializer.serialize(file_path)
        assert isinstance(data, bytes)

    # Deserialize (works purely from the in-memory bytes, so the file is no longer needed)
    deserialized_data = serializer.deserialize(data)
    assert isinstance(deserialized_data, np.ndarray)
    assert deserialized_data.shape == (height, width, bands)
    assert deserialized_data.dtype == np.uint16

    # Validate data content
    assert np.array_equal(np_data, deserialized_data)
Loading