Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

fix channel dim selection on segmentation target #1509

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flash/image/segmentation/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def load_data(

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
if DataKeys.TARGET in sample:
sample[DataKeys.TARGET] = np.array(load_image(sample[DataKeys.TARGET])).transpose((2, 0, 1))[:, :, 0]
sample[DataKeys.TARGET] = np.array(load_image(sample[DataKeys.TARGET])).transpose((2, 0, 1))[0, :, :]
return super().load_sample(sample)


Expand Down
76 changes: 72 additions & 4 deletions tests/image/segmentation/test_data.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import os
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Callable, Dict, List, Tuple

import numpy as np
import pytest
import torch

from flash import Trainer
from flash.core.data.io.input import DataKeys
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.transforms import ApplyToKeys
from flash.core.data.utilities.loading import load_image
from flash.core.utilities.imports import (
_FIFTYONE_AVAILABLE,
_IMAGE_AVAILABLE,
Expand All @@ -16,6 +19,7 @@
_PIL_AVAILABLE,
)
from flash.image import SemanticSegmentation, SemanticSegmentationData
from flash.image.segmentation.input import SemanticSegmentationFilesInput

if _PIL_AVAILABLE:
from PIL import Image
Expand Down Expand Up @@ -43,12 +47,22 @@ def _rand_labels(size: Tuple[int, int], num_classes: int):
return Image.fromarray(data.astype(np.uint8))


def create_random_data(image_files: List[str], label_files: List[str], size: Tuple[int, int], num_classes: int):
def create_random_data(
image_files: List[str], label_files: List[str], size: Tuple[int, int], num_classes: int
) -> Tuple[List[Image.Image], List[Image.Image]]:
imgs = []
for img_file in image_files:
_rand_image(size).save(img_file)
img = _rand_image(size)
img.save(img_file)
imgs.append(img)

labels = []
for label_file in label_files:
_rand_labels(size, num_classes).save(label_file)
label = _rand_labels(size, num_classes)
label.save(label_file)
labels.append(label)

return imgs, labels


class TestSemanticSegmentationData:
Expand All @@ -58,6 +72,60 @@ def test_smoke():
dm = SemanticSegmentationData(batch_size=1)
assert dm is not None

@staticmethod
@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
def test_identity(tmpdir):
class IdentityTransform(InputTransform):
def per_sample_transform(self) -> Callable:
return ApplyToKeys(
DataKeys.INPUT,
np.array,
)

def per_batch_transform(self) -> Callable:
return lambda x: x

tmp_dir = Path(tmpdir)

# create random dummy data

os.makedirs(str(tmp_dir / "images"))
os.makedirs(str(tmp_dir / "targets"))

images = [
str(tmp_dir / "images" / "img1.png")
]

targets = [
str(tmp_dir / "targets" / "img1.png")
]

num_classes: int = 2
img_size: Tuple[int, int] = (128, 128)
images_data, targets_data = create_random_data(images, targets, img_size, num_classes)

# instantiate the data module

dm = SemanticSegmentationData.from_files(
test_files=images,
test_targets=targets,
batch_size=1,
num_workers=0,
num_classes=num_classes,
transform=IdentityTransform(),
)

assert dm is not None
assert dm.test_dataloader() is not None

# check test data
data = next(iter(dm.test_dataloader()))
imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET]
assert imgs.shape == (1, 128, 128, 3)
assert labels.shape == (1, 128, 128)
assert torch.allclose(imgs, torch.from_numpy(np.array(images_data[0])))
assert torch.allclose(labels, torch.from_numpy(np.array(targets_data[0]))[:, :, 0])

@staticmethod
@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
def test_from_folders(tmpdir):
Expand Down