Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add from numpy, tensors, images to object detection data
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Jun 30, 2022
1 parent b1da344 commit dc16e80
Show file tree
Hide file tree
Showing 4 changed files with 486 additions and 18 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for loading `ImageClassificationData` from PIL images with `from_images` ([#1372](https://github.com/PyTorchLightning/lightning-flash/pull/1372))

- Added support for loading `ObjectDetectionData` with `from_numpy`, `from_images`, and `from_tensors` ([#1372](https://github.com/PyTorchLightning/lightning-flash/pull/1372))

### Changed

- Changed the `ImageEmbedder` dependency on VISSL to optional ([#1276](https://github.com/PyTorchLightning/lightning-flash/pull/1276))
Expand Down
28 changes: 14 additions & 14 deletions flash/image/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,13 +390,13 @@ def from_numpy(
@classmethod
def from_images(
cls,
train_data: Optional[List[Image.Image]] = None,
train_images: Optional[List[Image.Image]] = None,
train_targets: Optional[Sequence[Any]] = None,
val_data: Optional[List[Image.Image]] = None,
val_images: Optional[List[Image.Image]] = None,
val_targets: Optional[Sequence[Any]] = None,
test_data: Optional[List[Image.Image]] = None,
test_images: Optional[List[Image.Image]] = None,
test_targets: Optional[Sequence[Any]] = None,
predict_data: Optional[List[Image.Image]] = None,
predict_images: Optional[List[Image.Image]] = None,
target_formatter: Optional[TargetFormatter] = None,
input_cls: Type[Input] = ImageClassificationImageInput,
transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform,
Expand All @@ -412,13 +412,13 @@ def from_images(
:ref:`customizing transforms guide <customizing_transforms>`.
Args:
train_data: The list of PIL images to use when training.
train_images: The list of PIL images to use when training.
train_targets: The list of targets to use when training.
val_data: The list of PIL images to use when validating.
val_images: The list of PIL images to use when validating.
val_targets: The list of targets to use when validating.
test_data: The list of PIL images to use when testing.
test_images: The list of PIL images to use when testing.
test_targets: The list of targets to use when testing.
predict_data: The list of PIL images to use when predicting.
predict_images: The list of PIL images to use when predicting.
target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to
control how targets are handled. See :ref:`formatting_classification_targets` for more details.
input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
Expand All @@ -440,13 +440,13 @@ def from_images(
>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_images(
... train_data=[
... train_images=[
... Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")),
... Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")),
... Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")),
... ],
... train_targets=["cat", "dog", "cat"],
... predict_data=[Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))],
... predict_images=[Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
Expand All @@ -465,14 +465,14 @@ def from_images(
target_formatter=target_formatter,
)

train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, **ds_kw)
train_input = input_cls(RunningStage.TRAINING, train_images, train_targets, **ds_kw)
ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)

return cls(
train_input,
input_cls(RunningStage.VALIDATING, val_data, val_targets, **ds_kw),
input_cls(RunningStage.TESTING, test_data, test_targets, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_data, **ds_kw),
input_cls(RunningStage.VALIDATING, val_images, val_targets, **ds_kw),
input_cls(RunningStage.TESTING, test_images, test_targets, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_images, **ds_kw),
transform=transform,
transform_kwargs=transform_kwargs,
**data_module_kwargs,
Expand Down
Loading

0 comments on commit dc16e80

Please sign in to comment.