Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
add instance segmentation per_sample_transform (#1353)
Browse files Browse the repository at this point in the history
Co-authored-by: Anton Shevtsov <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: fatih <[email protected]>
Co-authored-by: karthikrangasai <[email protected]>
Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
7 people authored Jun 28, 2022
1 parent 0001838 commit 0ae31c8
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 7 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed a bug where the size of the masks for instance segmentation didn't match the size of the original image. ([#1353](https://github.com/PyTorchLightning/lightning-flash/pull/1353))

- Fixed image classification data `show_train_batch` for subplots with rows > 1. ([#1339](https://github.com/PyTorchLightning/lightning-flash/pull/1339))

- Fixed support for all the versions (including the latest and older) of `baal`. ([#1315](https://github.com/PyTorchLightning/lightning-flash/pull/1315))
Expand Down
20 changes: 18 additions & 2 deletions flash/core/integrations/icevision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Callable, Dict, List
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import torch
Expand All @@ -35,12 +35,18 @@
ImageRecordComponent,
InstancesLabelsRecordComponent,
KeyPointsRecordComponent,
RecordComponent,
RecordIDRecordComponent,
)
from icevision.data.prediction import Prediction
from icevision.tfms import A
else:
MaskArray = object
RecordComponent = object

class tasks:
common = object


if _ICEVISION_AVAILABLE and _ICEVISION_GREATER_EQUAL_0_11_0:
from icevision.core.record_components import InstanceMasksRecordComponent
Expand All @@ -53,6 +59,13 @@ def _split_mask_array(mask_array: MaskArray) -> List[MaskArray]:
return [MaskArray(mask) for mask in mask_array.data]


class OriginalSizeRecordComponent(RecordComponent):
    """Record component that carries the original size of an image.

    Args:
        original_size: The original size as ``(height, width)``, or ``None``.
        task: The task this component is registered under.
    """

    def __init__(self, original_size: Optional[Tuple[int, int]], task=tasks.common):
        super().__init__(task=task)
        # Stored in (h, w) order.
        self.original_size: Optional[Tuple[int, int]] = original_size


def to_icevision_record(sample: Dict[str, Any]):
record = BaseRecord([])

Expand All @@ -79,6 +92,8 @@ def to_icevision_record(sample: Dict[str, Any]):
image = sample[DataKeys.INPUT]
image = image.permute(1, 2, 0).numpy() if isinstance(image, torch.Tensor) else image
input_component.set_img(image)

record.add_component(OriginalSizeRecordComponent(metadata.get("size", image.shape[:2])))
record.add_component(input_component)

if DataKeys.TARGET in sample:
Expand Down Expand Up @@ -205,7 +220,8 @@ def from_icevision_detection(record: "BaseRecord"):
def from_icevision_record(record: "BaseRecord"):
sample = {
DataKeys.METADATA: {
"size": (record.height, record.width),
"size": getattr(record, "original_size", (record.height, record.width)),
"output_size": (record.height, record.width),
}
}

Expand Down
2 changes: 1 addition & 1 deletion flash/image/detection/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def transform(self, sample: Dict[str, Any]) -> Union[Detections, Dict[str, Any]]
if DataKeys.METADATA not in sample:
raise ValueError("sample requires DataKeys.METADATA to use a FiftyOneDetectionLabelsOutput output.")

height, width = sample[DataKeys.METADATA]["size"]
height, width = sample[DataKeys.METADATA]["output_size"]

detections = []

Expand Down
14 changes: 10 additions & 4 deletions flash/image/instance_segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Type, Union

from torch import tensor

from flash.core.data.data_module import DataModule
from flash.core.data.io.input import DataKeys, Input
from flash.core.data.io.output_transform import OutputTransform
from flash.core.data.utilities.sort import sorted_alphanumeric
from flash.core.integrations.icevision.data import IceVisionInput
from flash.core.integrations.icevision.transforms import IceVisionInputTransform
from flash.core.utilities.imports import _ICEVISION_AVAILABLE
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _KORNIA_AVAILABLE
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE

Expand All @@ -32,16 +34,20 @@
VOCMaskParser = object
Parser = object

if _KORNIA_AVAILABLE:
import kornia as K


# Skip doctests if requirements aren't available
if not _ICEVISION_AVAILABLE:
__doctest_skip__ = ["InstanceSegmentationData", "InstanceSegmentationData.*"]


class InstanceSegmentationOutputTransform(OutputTransform):
@staticmethod
def uncollate(batch: Any) -> Any:
return batch[DataKeys.PREDS]
def per_sample_transform(self, sample: Any) -> Any:
    """Resize each predicted mask to the size stored in the sample metadata.

    Args:
        sample: A mapping holding ``DataKeys.PREDS`` (with a ``"masks"`` entry) and
            ``DataKeys.METADATA`` (with a ``"size"`` entry).

    Returns:
        The ``DataKeys.PREDS`` entry of ``sample``, whose ``"masks"`` have been
        replaced in place by the resized tensors.
    """
    target_size = sample[DataKeys.METADATA]["size"]
    # Nearest-neighbour interpolation is requested for the mask resize.
    resize_to_target = K.geometry.Resize(target_size, interpolation="nearest")

    resized_masks = []
    for mask in sample[DataKeys.PREDS]["masks"]:
        resized_masks.append(resize_to_target(tensor(mask)))

    predictions = sample[DataKeys.PREDS]
    predictions["masks"] = resized_masks
    return predictions


class InstanceSegmentationData(DataModule):
Expand Down
1 change: 1 addition & 0 deletions tests/image/detection/test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def test_serialize_fiftyone():
DataKeys.METADATA: {
"filepath": "something",
"size": (100, 100),
"output_size": (100, 100),
},
}

Expand Down
31 changes: 31 additions & 0 deletions tests/image/instance_segmentation/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
import torch

from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _COCO_AVAILABLE, _IMAGE_AVAILABLE
from flash.image.instance_segmentation import InstanceSegmentationData
from flash.image.instance_segmentation.data import InstanceSegmentationOutputTransform
from tests.image.detection.test_data import _create_synth_files_dataset, _create_synth_folders_dataset


Expand Down Expand Up @@ -43,3 +46,31 @@ def test_image_detector_data_from_folders(tmpdir):
data = next(iter(datamodule.predict_dataloader()))
sample = data[0]
assert sample[DataKeys.INPUT].shape == (128, 128, 3)


@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing")
def test_instance_segmentation_output_transform():
    """Check that the output transform resizes predicted masks to the metadata size."""
    image = torch.rand(3, 224, 224)
    # Two 128x128 binary masks, deliberately smaller than the 224x224 metadata size.
    predicted_masks = [np.random.randint(2, size=(128, 128), dtype=np.uint8) for _ in range(2)]
    predictions = {
        "bboxes": [
            {"xmin": 10, "ymin": 10, "width": 15, "height": 15},
            {"xmin": 30, "ymin": 30, "width": 40, "height": 40},
        ],
        "labels": [0, 1],
        "masks": predicted_masks,
        "scores": [0.5, 0.5],
    }
    sample = {
        DataKeys.INPUT: image,
        DataKeys.PREDS: predictions,
        DataKeys.METADATA: {"size": (224, 224)},
    }

    transformed = InstanceSegmentationOutputTransform().per_sample_transform(sample)

    for index in (0, 1):
        assert transformed["masks"][index].size() == (224, 224)

0 comments on commit 0ae31c8

Please sign in to comment.