From dd786be4879994d4ef5e71014063037721d746a5 Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Tue, 14 Dec 2021 18:27:37 +0000
Subject: [PATCH 1/6] Fix od / kd/ is examples and add to CI

---
 flash/image/detection/data.py                  |  4 ++--
 flash/image/instance_segmentation/data.py      |  4 ++--
 flash/image/keypoint_detection/data.py         |  4 ++--
 flash_examples/instance_segmentation.py        |  2 +-
 .../integrations/fiftyone/object_detection.py  |  2 +-
 flash_examples/keypoint_detection.py           |  2 +-
 tests/examples/test_integrations.py            |  7 +++++++
 tests/examples/test_scripts.py                 | 20 ++++++++++++++++++-
 8 files changed, 35 insertions(+), 10 deletions(-)

diff --git a/flash/image/detection/data.py b/flash/image/detection/data.py
index 9425dd01a8..bba8a852f9 100644
--- a/flash/image/detection/data.py
+++ b/flash/image/detection/data.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
+from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING
 
 from flash.core.data.data_module import DataModule
 from flash.core.data.data_pipeline import DataPipelineState
@@ -58,7 +58,7 @@ def from_icedata(
         val_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
         test_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
         predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
-        parser: Type[Parser] = Parser,
+        parser: Optional[Callable, Type[Parser]] = None,
         input_cls: Type[Input] = IceVisionInput,
         transform_kwargs: Optional[Dict] = None,
         **data_module_kwargs,
diff --git a/flash/image/instance_segmentation/data.py b/flash/image/instance_segmentation/data.py
index c08b2e7f2d..6f2c1dad4f 100644
--- a/flash/image/instance_segmentation/data.py
+++ b/flash/image/instance_segmentation/data.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Optional, Type
+from typing import Any, Callable, Dict, List, Optional, Type
 
 from flash.core.data.data_module import DataModule
 from flash.core.data.data_pipeline import DataPipelineState
@@ -56,7 +56,7 @@ def from_icedata(
         val_transform: INPUT_TRANSFORM_TYPE = InstanceSegmentationInputTransform,
         test_transform: INPUT_TRANSFORM_TYPE = InstanceSegmentationInputTransform,
         predict_transform: INPUT_TRANSFORM_TYPE = InstanceSegmentationInputTransform,
-        parser: Optional[Type[Parser]] = Parser,
+        parser: Optional[Callable, Type[Parser]] = None,
         input_cls: Type[Input] = IceVisionInput,
         transform_kwargs: Optional[Dict] = None,
         **data_module_kwargs,
diff --git a/flash/image/keypoint_detection/data.py b/flash/image/keypoint_detection/data.py
index c3b8ef761b..1b544fd33f 100644
--- a/flash/image/keypoint_detection/data.py
+++ b/flash/image/keypoint_detection/data.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Optional, Type
+from typing import Any, Callable, Dict, List, Optional, Type
 
 from flash.core.data.data_module import DataModule
 from flash.core.data.data_pipeline import DataPipelineState
@@ -47,7 +47,7 @@ def from_icedata(
         val_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
         test_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
         predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
-        parser: Optional[Type[Parser]] = Parser,
+        parser: Optional[Callable, Type[Parser]] = None,
         input_cls: Type[Input] = IceVisionInput,
         transform_kwargs: Optional[Dict] = None,
         **data_module_kwargs: Any,
diff --git a/flash_examples/instance_segmentation.py b/flash_examples/instance_segmentation.py
index 86595eaefa..af16833f0c 100644
--- a/flash_examples/instance_segmentation.py
+++ b/flash_examples/instance_segmentation.py
@@ -24,7 +24,7 @@
 
 # 1. Create the DataModule
 data_dir = icedata.pets.load_data()
-datamodule = InstanceSegmentationData.from_folders(
+datamodule = InstanceSegmentationData.from_icedata(
     train_folder=data_dir,
     val_split=0.1,
     parser=partial(icedata.pets.parser, mask=True),
diff --git a/flash_examples/integrations/fiftyone/object_detection.py b/flash_examples/integrations/fiftyone/object_detection.py
index ef3e25c25c..ac837f5ff7 100644
--- a/flash_examples/integrations/fiftyone/object_detection.py
+++ b/flash_examples/integrations/fiftyone/object_detection.py
@@ -26,7 +26,7 @@
 
 # 1. Create the DataModule
 data_dir = icedata.fridge.load_data()
-datamodule = ObjectDetectionData.from_folders(
+datamodule = ObjectDetectionData.from_icedata(
     train_folder=data_dir,
     predict_folder=data_dir,
     val_split=0.1,
diff --git a/flash_examples/keypoint_detection.py b/flash_examples/keypoint_detection.py
index cc7e4aa0d2..a94372e252 100644
--- a/flash_examples/keypoint_detection.py
+++ b/flash_examples/keypoint_detection.py
@@ -22,7 +22,7 @@
 
 # 1. Create the DataModule
 data_dir = icedata.biwi.load_data()
-datamodule = KeypointDetectionData.from_folders(
+datamodule = KeypointDetectionData.from_icedata(
     train_folder=data_dir,
     val_split=0.1,
     parser=icedata.biwi.parser,
diff --git a/tests/examples/test_integrations.py b/tests/examples/test_integrations.py
index 7527c276de..4588c1445b 100644
--- a/tests/examples/test_integrations.py
+++ b/tests/examples/test_integrations.py
@@ -34,6 +34,13 @@
                 not (_IMAGE_AVAILABLE and _FIFTYONE_AVAILABLE), reason="fiftyone library isn't installed"
             ),
         ),
+        pytest.param(
+            "fiftyone",
+            "object_detection.py",
+            marks=pytest.mark.skipif(
+                not (_IMAGE_AVAILABLE and _FIFTYONE_AVAILABLE), reason="fiftyone library isn't installed"
+            ),
+        ),
         pytest.param(
             "baal",
             "image_classification_active_learning.py",
diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py
index 78ebb08bf1..789f3345bf 100644
--- a/tests/examples/test_scripts.py
+++ b/tests/examples/test_scripts.py
@@ -17,7 +17,7 @@
 
 import pytest
 
-from flash.core.utilities.imports import _SKLEARN_AVAILABLE
+from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE, _SKLEARN_AVAILABLE
 from tests.examples.utils import run_test
 from tests.helpers.utils import (
     _AUDIO_TESTING,
@@ -52,6 +52,24 @@
             "image_classification_multi_label.py",
             marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed"),
         ),
+        pytest.param(
+            "object_detection.py",
+            marks=pytest.mark.skipif(
+                not (_IMAGE_AVAILABLE and _ICEVISION_AVAILABLE), reason="image libraries aren't installed"
+            ),
+        ),
+        pytest.param(
+            "instance_segmentation.py",
+            marks=pytest.mark.skipif(
+                not (_IMAGE_AVAILABLE and _ICEVISION_AVAILABLE), reason="image libraries aren't installed"
+            ),
+        ),
+        pytest.param(
+            "keypoint_detection.py",
+            marks=pytest.mark.skipif(
+                not (_IMAGE_AVAILABLE and _ICEVISION_AVAILABLE), reason="image libraries aren't installed"
+            ),
+        ),
         # pytest.param("finetuning", "object_detection.py"), # TODO: takes too long.
         pytest.param(
             "question_answering.py",

From 3a3b4a42456380e46ddb6750b8ea1651f1bfbf8d Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Tue, 14 Dec 2021 18:49:50 +0000
Subject: [PATCH 2/6] Fix typing

---
 flash/image/detection/data.py             | 4 ++--
 flash/image/instance_segmentation/data.py | 4 ++--
 flash/image/keypoint_detection/data.py    | 4 ++--
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/flash/image/detection/data.py b/flash/image/detection/data.py
index bba8a852f9..ffd8871187 100644
--- a/flash/image/detection/data.py
+++ b/flash/image/detection/data.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING
+from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING, Union
 
 from flash.core.data.data_module import DataModule
 from flash.core.data.data_pipeline import DataPipelineState
@@ -58,7 +58,7 @@ def from_icedata(
         val_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
         test_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
         predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
-        parser: Optional[Callable, Type[Parser]] = None,
+        parser: Optional[Union[Callable, Type[Parser]]] = None,
         input_cls: Type[Input] = IceVisionInput,
         transform_kwargs: Optional[Dict] = None,
         **data_module_kwargs,
diff --git a/flash/image/instance_segmentation/data.py b/flash/image/instance_segmentation/data.py
index 6f2c1dad4f..cd6b267059 100644
--- a/flash/image/instance_segmentation/data.py
+++ b/flash/image/instance_segmentation/data.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Dict, List, Optional, Type
+from typing import Any, Callable, Dict, List, Optional, Type, Union
 
 from flash.core.data.data_module import DataModule
 from flash.core.data.data_pipeline import DataPipelineState
@@ -56,7 +56,7 @@ def from_icedata(
         val_transform: INPUT_TRANSFORM_TYPE = InstanceSegmentationInputTransform,
         test_transform: INPUT_TRANSFORM_TYPE = InstanceSegmentationInputTransform,
         predict_transform: INPUT_TRANSFORM_TYPE = InstanceSegmentationInputTransform,
-        parser: Optional[Callable, Type[Parser]] = None,
+        parser: Optional[Union[Callable, Type[Parser]]] = None,
         input_cls: Type[Input] = IceVisionInput,
         transform_kwargs: Optional[Dict] = None,
         **data_module_kwargs,
diff --git a/flash/image/keypoint_detection/data.py b/flash/image/keypoint_detection/data.py
index 1b544fd33f..3e0dfb366e 100644
--- a/flash/image/keypoint_detection/data.py
+++ b/flash/image/keypoint_detection/data.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Dict, List, Optional, Type
+from typing import Any, Callable, Dict, List, Optional, Type, Union
 
 from flash.core.data.data_module import DataModule
 from flash.core.data.data_pipeline import DataPipelineState
@@ -47,7 +47,7 @@ def from_icedata(
         val_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
         test_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
         predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
-        parser: Optional[Callable, Type[Parser]] = None,
+        parser: Optional[Union[Callable, Type[Parser]]] = None,
         input_cls: Type[Input] = IceVisionInput,
         transform_kwargs: Optional[Dict] = None,
         **data_module_kwargs: Any,

From 5c35b32b972a763b24e3f12ae4fd84e22c81e0a0 Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Tue, 14 Dec 2021 19:02:06 +0000
Subject: [PATCH 3/6] Fixes

---
 flash_examples/instance_segmentation.py                  | 4 +++-
 flash_examples/integrations/fiftyone/object_detection.py | 3 ++-
 flash_examples/keypoint_detection.py                     | 4 +++-
 flash_examples/object_detection.py                       | 7 +++++--
 4 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/flash_examples/instance_segmentation.py b/flash_examples/instance_segmentation.py
index af16833f0c..cba5e6e8c1 100644
--- a/flash_examples/instance_segmentation.py
+++ b/flash_examples/instance_segmentation.py
@@ -28,6 +28,7 @@
     train_folder=data_dir,
     val_split=0.1,
     parser=partial(icedata.pets.parser, mask=True),
+    batch_size=4,
 )
 
 # 2. Build the task
@@ -47,7 +48,8 @@
         str(data_dir / "images/yorkshire_terrier_9.jpg"),
         str(data_dir / "images/yorkshire_terrier_12.jpg"),
         str(data_dir / "images/yorkshire_terrier_13.jpg"),
-    ]
+    ],
+    batch_size=4,
 )
 predictions = trainer.predict(model, datamodule=datamodule)
 print(predictions)
diff --git a/flash_examples/integrations/fiftyone/object_detection.py b/flash_examples/integrations/fiftyone/object_detection.py
index ac837f5ff7..e3e1f8e99a 100644
--- a/flash_examples/integrations/fiftyone/object_detection.py
+++ b/flash_examples/integrations/fiftyone/object_detection.py
@@ -30,8 +30,9 @@
     train_folder=data_dir,
     predict_folder=data_dir,
     val_split=0.1,
-    image_size=128,
+    transform_kwargs={"image_size": 128},
     parser=icedata.fridge.parser,
+    batch_size=4,
 )
 
 # 2. Build the task
diff --git a/flash_examples/keypoint_detection.py b/flash_examples/keypoint_detection.py
index a94372e252..e04636da98 100644
--- a/flash_examples/keypoint_detection.py
+++ b/flash_examples/keypoint_detection.py
@@ -26,6 +26,7 @@
     train_folder=data_dir,
     val_split=0.1,
     parser=icedata.biwi.parser,
+    batch_size=4,
 )
 
 # 2. Build the task
@@ -46,7 +47,8 @@
         str(data_dir / "biwi_sample/images/0.jpg"),
         str(data_dir / "biwi_sample/images/1.jpg"),
         str(data_dir / "biwi_sample/images/10.jpg"),
-    ]
+    ],
+    batch_size=4,
 )
 predictions = trainer.predict(model, datamodule=datamodule)
 print(predictions)
diff --git a/flash_examples/object_detection.py b/flash_examples/object_detection.py
index e10b813253..4fae590614 100644
--- a/flash_examples/object_detection.py
+++ b/flash_examples/object_detection.py
@@ -23,7 +23,8 @@
     train_folder="data/coco128/images/train2017/",
     train_ann_file="data/coco128/annotations/instances_train2017.json",
     val_split=0.1,
-    image_size=128,
+    transform_kwargs={"image_size": 128},
+    batch_size=4,
 )
 
 # 2. Build the task
@@ -39,7 +40,9 @@
         "data/coco128/images/train2017/000000000625.jpg",
         "data/coco128/images/train2017/000000000626.jpg",
         "data/coco128/images/train2017/000000000629.jpg",
-    ]
+    ],
+    transform_kwargs={"image_size": 128},
+    batch_size=4,
 )
 predictions = trainer.predict(model, datamodule=datamodule)
 print(predictions)

From 0f49513c24bc5c441d4ecc00e4ccc381030ca94c Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Tue, 14 Dec 2021 20:19:46 +0000
Subject: [PATCH 4/6] Fixes

---
 flash/core/data/data_module.py        | 149 +++++++++++++++++---
 flash/core/data/io/input_transform.py |  13 ++-
 2 files changed, 107 insertions(+), 55 deletions(-)

diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py
index 5b1fd02600..f273d18383 100644
--- a/flash/core/data/data_module.py
+++ b/flash/core/data/data_module.py
@@ -28,7 +28,7 @@
 from flash.core.data.callback import BaseDataFetcher
 from flash.core.data.data_pipeline import DataPipeline, DataPipelineState
 from flash.core.data.io.input import DataKeys, Input, InputBase, IterableInput
-from flash.core.data.io.input_transform import InputTransform
+from flash.core.data.io.input_transform import _InputTransformProcessorV2, InputTransform
 from flash.core.data.io.output_transform import OutputTransform
 from flash.core.data.splits import SplitDataset
 from flash.core.data.utils import _STAGES_PREFIX
@@ -193,8 +193,18 @@ def _resolve_on_after_batch_transfer_fn(self, ds: Optional[Input]) -> Optional[C
             return ds._create_on_after_batch_transfer_fn([self.data_fetcher])
 
     def _train_dataloader(self) -> DataLoader:
+        if isinstance(getattr(self, "trainer", None), pl.Trainer):
+            if isinstance(self.trainer.lightning_module, flash.Task):
+                self.connect(self.trainer.lightning_module)
+
         train_ds: Input = self._train_input
         collate_fn = self._train_dataloader_collate_fn
+
+        transform_processor = None
+        if isinstance(collate_fn, _InputTransformProcessorV2):
+            transform_processor = collate_fn
+            collate_fn = transform_processor.collate_fn
+
         shuffle: bool = False
         if isinstance(train_ds, IterableDataset):
             drop_last = False
@@ -208,9 +218,7 @@ def _train_dataloader(self) -> DataLoader:
             sampler = self.sampler(train_ds)
 
         if isinstance(getattr(self, "trainer", None), pl.Trainer):
-            if isinstance(self.trainer.lightning_module, flash.Task):
-                self.connect(self.trainer.lightning_module)
-            return self.trainer.lightning_module.process_train_dataset(
+            dataloader = self.trainer.lightning_module.process_train_dataset(
                 train_ds,
                 trainer=self.trainer,
                 batch_size=self.batch_size,
@@ -221,27 +229,40 @@ def _train_dataloader(self) -> DataLoader:
                 collate_fn=collate_fn,
                 sampler=sampler,
             )
+        else:
+            dataloader = DataLoader(
+                train_ds,
+                batch_size=self.batch_size,
+                shuffle=shuffle,
+                sampler=sampler,
+                num_workers=self.num_workers,
+                pin_memory=self.pin_memory,
+                drop_last=drop_last,
+                collate_fn=collate_fn,
+                persistent_workers=self.persistent_workers,
+            )
 
-        return DataLoader(
-            train_ds,
-            batch_size=self.batch_size,
-            shuffle=shuffle,
-            sampler=sampler,
-            num_workers=self.num_workers,
-            pin_memory=self.pin_memory,
-            drop_last=drop_last,
-            collate_fn=collate_fn,
-            persistent_workers=self.persistent_workers,
-        )
+        if transform_processor is not None:
+            transform_processor.collate_fn = dataloader.collate_fn
+            dataloader.collate_fn = transform_processor
+
+        return dataloader
 
     def _val_dataloader(self) -> DataLoader:
+        if isinstance(getattr(self, "trainer", None), pl.Trainer):
+            if isinstance(self.trainer.lightning_module, flash.Task):
+                self.connect(self.trainer.lightning_module)
+
         val_ds: Input = self._val_input
         collate_fn = self._val_dataloader_collate_fn
 
+        transform_processor = None
+        if isinstance(collate_fn, _InputTransformProcessorV2):
+            transform_processor = collate_fn
+            collate_fn = transform_processor.collate_fn
+
         if isinstance(getattr(self, "trainer", None), pl.Trainer):
-            if isinstance(self.trainer.lightning_module, flash.Task):
-                self.connect(self.trainer.lightning_module)
-            return self.trainer.lightning_module.process_val_dataset(
+            dataloader = self.trainer.lightning_module.process_val_dataset(
                 val_ds,
                 trainer=self.trainer,
                 batch_size=self.batch_size,
@@ -249,24 +270,37 @@ def _val_dataloader(self) -> DataLoader:
                 pin_memory=self.pin_memory,
                 collate_fn=collate_fn,
             )
+        else:
+            dataloader = DataLoader(
+                val_ds,
+                batch_size=self.batch_size,
+                num_workers=self.num_workers,
+                pin_memory=self.pin_memory,
+                collate_fn=collate_fn,
+                persistent_workers=self.persistent_workers,
+            )
 
-        return DataLoader(
-            val_ds,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            pin_memory=self.pin_memory,
-            collate_fn=collate_fn,
-            persistent_workers=self.persistent_workers,
-        )
+        if transform_processor is not None:
+            transform_processor.collate_fn = dataloader.collate_fn
+            dataloader.collate_fn = transform_processor
+
+        return dataloader
 
     def _test_dataloader(self) -> DataLoader:
+        if isinstance(getattr(self, "trainer", None), pl.Trainer):
+            if isinstance(self.trainer.lightning_module, flash.Task):
+                self.connect(self.trainer.lightning_module)
+
         test_ds: Input = self._test_input
         collate_fn = self._test_dataloader_collate_fn
 
+        transform_processor = None
+        if isinstance(collate_fn, _InputTransformProcessorV2):
+            transform_processor = collate_fn
+            collate_fn = transform_processor.collate_fn
+
         if isinstance(getattr(self, "trainer", None), pl.Trainer):
-            if isinstance(self.trainer.lightning_module, flash.Task):
-                self.connect(self.trainer.lightning_module)
-            return self.trainer.lightning_module.process_test_dataset(
+            dataloader = self.trainer.lightning_module.process_test_dataset(
                 test_ds,
                 trainer=self.trainer,
                 batch_size=self.batch_size,
@@ -274,44 +308,63 @@ def _test_dataloader(self) -> DataLoader:
                 pin_memory=self.pin_memory,
                 collate_fn=collate_fn,
             )
+        else:
+            dataloader = DataLoader(
+                test_ds,
+                batch_size=self.batch_size,
+                num_workers=self.num_workers,
+                pin_memory=self.pin_memory,
+                collate_fn=collate_fn,
+                persistent_workers=self.persistent_workers,
+            )
 
-        return DataLoader(
-            test_ds,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            pin_memory=self.pin_memory,
-            collate_fn=collate_fn,
-            persistent_workers=self.persistent_workers,
-        )
+        if transform_processor is not None:
+            transform_processor.collate_fn = dataloader.collate_fn
+            dataloader.collate_fn = transform_processor
+
+        return dataloader
 
     def _predict_dataloader(self) -> DataLoader:
+        if isinstance(getattr(self, "trainer", None), pl.Trainer):
+            if isinstance(self.trainer.lightning_module, flash.Task):
+                self.connect(self.trainer.lightning_module)
+
         predict_ds: Input = self._predict_input
         collate_fn = self._predict_dataloader_collate_fn
 
+        transform_processor = None
+        if isinstance(collate_fn, _InputTransformProcessorV2):
+            transform_processor = collate_fn
+            collate_fn = transform_processor.collate_fn
+
         if isinstance(predict_ds, IterableDataset):
             batch_size = self.batch_size
         else:
             batch_size = min(self.batch_size, len(predict_ds) if len(predict_ds) > 0 else 1)
 
         if isinstance(getattr(self, "trainer", None), pl.Trainer):
-            if isinstance(self.trainer.lightning_module, flash.Task):
-                self.connect(self.trainer.lightning_module)
-            return self.trainer.lightning_module.process_predict_dataset(
+            dataloader = self.trainer.lightning_module.process_predict_dataset(
                 predict_ds,
                 batch_size=batch_size,
                 num_workers=self.num_workers,
                 pin_memory=self.pin_memory,
                 collate_fn=collate_fn,
             )
+        else:
+            dataloader = DataLoader(
+                predict_ds,
+                batch_size=batch_size,
+                num_workers=self.num_workers,
+                pin_memory=self.pin_memory,
+                collate_fn=collate_fn,
+                persistent_workers=self.persistent_workers,
+            )
 
-        return DataLoader(
-            predict_ds,
-            batch_size=batch_size,
-            num_workers=self.num_workers,
-            pin_memory=self.pin_memory,
-            collate_fn=collate_fn,
-            persistent_workers=self.persistent_workers,
-        )
+        if transform_processor is not None:
+            transform_processor.collate_fn = dataloader.collate_fn
+            dataloader.collate_fn = transform_processor
+
+        return dataloader
 
     def connect(self, task: "flash.Task"):
         data_pipeline_state = DataPipelineState()
diff --git a/flash/core/data/io/input_transform.py b/flash/core/data/io/input_transform.py
index 27fecec1b2..f61c16dd7a 100644
--- a/flash/core/data/io/input_transform.py
+++ b/flash/core/data/io/input_transform.py
@@ -16,7 +16,6 @@
 from functools import partial, wraps
 from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
 
-import torch
 from pytorch_lightning.utilities.enums import LightningEnum
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from torch.utils.data._utils.collate import default_collate
@@ -33,7 +32,7 @@
     PerSampleTransformOnDevice,
 )
 from flash.core.data.transforms import ApplyToKeys
-from flash.core.data.utils import _INPUT_TRANSFORM_FUNCS, _STAGES_PREFIX, convert_to_modules
+from flash.core.data.utils import _INPUT_TRANSFORM_FUNCS, _STAGES_PREFIX
 from flash.core.registry import FlashRegistry
 from flash.core.utilities.stages import RunningStage
 from flash.core.utilities.types import INPUT_TRANSFORM_TYPE
@@ -1118,7 +1117,7 @@ def _make_collates(input_transform: "InputTransform", on_device: bool, collate:
     return collate, input_transform._identity
 
 
-class _InputTransformProcessorV2(torch.nn.Module):
+class _InputTransformProcessorV2:
    """
    This class is used to encapsulate the following functions of a InputTransformInputTransform Object:
    Inside a worker:
@@ -1146,9 +1145,9 @@ def __init__(
         super().__init__()
         self.input_transform = input_transform
         self.callback = ControlFlow(callbacks or [])
-        self.collate_fn = convert_to_modules(collate_fn)
-        self.per_sample_transform = convert_to_modules(per_sample_transform)
-        self.per_batch_transform = convert_to_modules(per_batch_transform)
+        self.collate_fn = collate_fn
+        self.per_sample_transform = per_sample_transform
+        self.per_batch_transform = per_batch_transform
         self.apply_per_sample_transform = apply_per_sample_transform
         self.stage = stage
         self.on_device = on_device
@@ -1160,7 +1159,7 @@ def _extract_metadata(
         metadata = [s.pop(DataKeys.METADATA, None) if isinstance(s, Mapping) else None for s in samples]
         return samples, metadata if any(m is not None for m in metadata) else None
 
-    def forward(self, samples: Sequence[Any]) -> Any:
+    def __call__(self, samples: Sequence[Any]) -> Any:
         if not self.on_device:
             for sample in samples:
                 self.callback.on_load_sample(sample, self.stage)

From 07318489c3cb2a39f48e4eae601018a01479f904 Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Tue, 14 Dec 2021 21:05:11 +0000
Subject: [PATCH 5/6] Fixes

---
 flash/image/keypoint_detection/data.py     | 24 +++++++++----
 .../keypoint_detection/input_transform.py  | 24 +++++++++++++++++++
 2 files changed, 36 insertions(+), 12 deletions(-)
 create mode 100644 flash/image/keypoint_detection/input_transform.py

diff --git a/flash/image/keypoint_detection/data.py b/flash/image/keypoint_detection/data.py
index 3e0dfb366e..751a7513ec 100644
--- a/flash/image/keypoint_detection/data.py
+++ b/flash/image/keypoint_detection/data.py
@@ -17,10 +17,10 @@
 from flash.core.data.data_pipeline import DataPipelineState
 from flash.core.data.io.input import Input
 from flash.core.integrations.icevision.data import IceVisionInput
-from flash.core.integrations.icevision.transforms import IceVisionInputTransform
 from flash.core.utilities.imports import _ICEVISION_AVAILABLE
 from flash.core.utilities.stages import RunningStage
 from flash.core.utilities.types import INPUT_TRANSFORM_TYPE
+from flash.image.keypoint_detection.input_transform import KeypointDetectionInputTransform
 
 if _ICEVISION_AVAILABLE:
     from icevision.parsers import COCOKeyPointsParser, Parser
@@ -31,7 +31,7 @@
 
 class KeypointDetectionData(DataModule):
 
-    input_transform_cls = IceVisionInputTransform
+    input_transform_cls = KeypointDetectionInputTransform
 
     @classmethod
     def from_icedata(
@@ -43,10 +43,10 @@ def from_icedata(
         test_folder: Optional[str] = None,
         test_ann_file: Optional[str] = None,
         predict_folder: Optional[str] = None,
-        train_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
-        val_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
-        test_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
-        predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
+        train_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
+        val_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
+        test_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
+        predict_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
         parser: Optional[Union[Callable, Type[Parser]]] = None,
         input_cls: Type[Input] = IceVisionInput,
         transform_kwargs: Optional[Dict] = None,
@@ -73,10 +73,10 @@ def from_coco(
         test_folder: Optional[str] = None,
         test_ann_file: Optional[str] = None,
         predict_folder: Optional[str] = None,
-        train_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
-        val_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
-        test_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
-        predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
+        train_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
+        val_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
+        test_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
+        predict_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
         parser: Optional[Type[Parser]] = COCOKeyPointsParser,
         input_cls: Type[Input] = IceVisionInput,
         transform_kwargs: Optional[Dict] = None,
@@ -119,7 +119,7 @@ def from_coco(
     def from_folders(
         cls,
         predict_folder: Optional[str] = None,
-        predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
+        predict_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
         input_cls: Type[Input] = IceVisionInput,
         transform_kwargs: Optional[Dict] = None,
         **data_module_kwargs: Any,
@@ -149,7 +149,7 @@ def from_folders(
     def from_files(
        cls,
        predict_files: Optional[List[str]] = None,
-        predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
+        predict_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
         input_cls: Type[Input] = IceVisionInput,
         transform_kwargs: Optional[Dict] = None,
         **data_module_kwargs: Any,
diff --git a/flash/image/keypoint_detection/input_transform.py b/flash/image/keypoint_detection/input_transform.py
new file mode 100644
index 0000000000..bd741bd822
--- /dev/null
+++ b/flash/image/keypoint_detection/input_transform.py
@@ -0,0 +1,24 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from flash.core.integrations.icevision.transforms import IceVisionInputTransform, IceVisionTransformAdapter
+from flash.core.utilities.imports import _ICEVISION_AVAILABLE, requires
+
+if _ICEVISION_AVAILABLE:
+    from icevision.tfms import A
+
+
+class KeypointDetectionInputTransform(IceVisionInputTransform):
+    @requires(["image", "icevision"])
+    def train_per_sample_transform(self):
+        return IceVisionTransformAdapter([*A.aug_tfms(size=self.image_size, crop_fn=None), A.Normalize()])

From a5b71eabc998d5ec966d150203943f7c4dba3990 Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Tue, 14 Dec 2021 22:25:54 +0000
Subject: [PATCH 6/6] Update CHANGELOG.md

---
 CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e7238b28bb..20efe5a897 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed a bug when not explicitly passing `embedding_sizes` to the `TabularClassifier` and `TabularRegressor` tasks ([#1067](https://github.com/PyTorchLightning/lightning-flash/pull/1067))
 
+- Fixed a bug where under some circumstances transforms would not get called ([#1072](https://github.com/PyTorchLightning/lightning-flash/pull/1072))
+
 ### Removed
 
 ## [0.6.0] - 2021-13-12
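A note on the core change in PATCH 4/6: `_train_dataloader` (and the val/test/predict variants) now pull the `_InputTransformProcessorV2` off the collate function before the `DataLoader` is built, and then re-attach it around whatever collate function the `DataLoader` ends up with, so the per-sample and per-batch transforms are always invoked. The sketch below illustrates that wrapping pattern only; `CollateProcessor` and `build_dataloader` are generic stand-in names for this illustration, not part of Flash.

```python
from torch.utils.data import DataLoader


class CollateProcessor:
    """Illustrative stand-in for a transform processor: it applies per-sample and
    per-batch transforms around an inner collate function."""

    def __init__(self, collate_fn, per_sample_transform=lambda x: x, per_batch_transform=lambda x: x):
        self.collate_fn = collate_fn
        self.per_sample_transform = per_sample_transform
        self.per_batch_transform = per_batch_transform

    def __call__(self, samples):
        samples = [self.per_sample_transform(sample) for sample in samples]
        batch = self.collate_fn(samples)
        return self.per_batch_transform(batch)


def build_dataloader(dataset, collate_fn, **dataloader_kwargs):
    # Mirror the pattern from the patch: unwrap the processor, let the DataLoader
    # choose its collate_fn, then wrap that collate_fn with the processor again.
    processor = None
    if isinstance(collate_fn, CollateProcessor):
        processor = collate_fn
        collate_fn = processor.collate_fn

    dataloader = DataLoader(dataset, collate_fn=collate_fn, **dataloader_kwargs)

    if processor is not None:
        processor.collate_fn = dataloader.collate_fn
        dataloader.collate_fn = processor
    return dataloader
```

With this shape the processor still runs even when the collate function is replaced or defaulted while constructing the `DataLoader`, which is the situation the CHANGELOG entry ("transforms would not get called") refers to.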