Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Update formatting targets guide #1165

Merged
merged 7 commits into from
Feb 14, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where backbones for the `ObjectDetector`, `KeypointDetector`, and `InstanceSegmentation` tasks were not always frozen correctly when finetuning ([#1163](https://github.com/PyTorchLightning/lightning-flash/pull/1163))

- Fixed a bug where `DataModule.multi_label` would sometimes be `None` when it had been inferred to be `False` ([#1165](https://github.com/PyTorchLightning/lightning-flash/pull/1165))

### Removed

- Removed the `Seq2SeqData` base class (use `TranslationData` or `SummarizationData` directly) ([#1128](https://github.com/PyTorchLightning/lightning-flash/pull/1128))
Expand Down
211 changes: 210 additions & 1 deletion docs/source/general/classification_targets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,213 @@
Formatting Classification Targets
*********************************

.. note:: The contents of this page are currently being updated. Stay tuned!
This guide details the different target formats supported by classification tasks in Flash.
By default, the target format and any additional metadata (`labels`, `num_classes`, `multi_label`) will be inferred from your training data.

.. testsetup:: targets

import numpy as np
from PIL import Image

rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))
_ = [rand_image.save(f"image_{i}.png") for i in range(1, 4)]

Single Numeric
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
______________

Single numeric targets are represented by a single integer (`multi_label = False`).
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
No `labels` will be inferred.
The inferred `num_classes` is the maximum target value plus one (we assume that targets are zero-based).
Here's an example:

.. doctest:: targets

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=[0, 1, 0],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels is None
True
>>> datamodule.multi_label
False

Single Labels
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
_____________

Single labels are targets represented by a single string (`multi_label = False`).
The inferred `labels` will be the unique labels in the train targets sorted alphanumerically.
The inferred `num_classes` is the number of labels.
Here's an example:

.. doctest:: targets

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=["cat", "dog", "cat"],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> datamodule.multi_label
False

Single Binary
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
_____________

Single binary targets are represented by a one-hot encoded binary list (`multi_label = False`).
No `labels` will be inferred.
The inferred `num_classes` is the length of the binary list.
Here's an example:

.. doctest:: targets

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=[[1, 0], [0, 1], [1, 0]],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels is None
True
>>> datamodule.multi_label
False

Multi Numeric
_____________

Multi numeric targets are represented by a list of integer class indexes (`multi_label = True`).
No `labels` will be inferred.
The inferred `num_classes` is the maximum target value plus one (we assume that targets are zero-based).
Here's an example:

.. doctest:: targets

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=[[0], [0, 1], [1, 2]],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels is None
True
>>> datamodule.multi_label
True

Multi Labels
____________

Multi labels are targets represented by a list of strings (`multi_label = True`).
The inferred `labels` will be the unique labels in the train targets sorted alphanumerically.
The inferred `num_classes` is the number of labels.
Here's an example:

.. doctest:: targets

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=[["cat"], ["cat", "dog"], ["dog", "rabbit"]],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
["cat", "dog", "rabbit"]
>>> datamodule.multi_label
True

Comma Delimited
_______________

Comma delimited targets are mutli label targets where the labels are given as comma delimited strings (`multi_label = True`).
The inferred `labels` will be the unique labels in the train targets sorted alphanumerically.
The inferred `num_classes` is the number of labels.
Here's an example:

.. doctest:: targets

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=["cat", "cat,dog", "dog,rabbit"],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
["cat", "dog", "rabbit"]
>>> datamodule.multi_label
True

Space Delimited
_______________

Space delimited targets are mutli label targets where the labels are given as space delimited strings (`multi_label = True`).
The inferred `labels` will be the unique labels in the train targets sorted alphanumerically.
The inferred `num_classes` is the number of labels.
Here's an example:

.. doctest:: targets

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=["cat", "cat dog", "dog rabbit"],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
["cat", "dog", "rabbit"]
>>> datamodule.multi_label
True

Multi Binary
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
____________

Multi binary targets are represented by a multi-hot encoded binary list (`multi_label = False`).
No `labels` will be inferred.
The inferred `num_classes` is the length of the binary list.
Here's an example:

.. doctest:: targets

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=[[1, 0, 0], [1, 1, 0], [0, 1, 1]],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels is None
True
>>> datamodule.multi_label
True
22 changes: 10 additions & 12 deletions flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,29 +481,27 @@ def show_predict_batch(self, hooks_names: Union[str, List[str]] = "load_sample",
stage_name: str = _STAGES_PREFIX[RunningStage.PREDICTING]
self._show_batch(stage_name, hooks_names, reset=reset)

def _get_property(self, property_name: str) -> Optional[Any]:
train = getattr(self.train_dataset, property_name, None)
val = getattr(self.val_dataset, property_name, None)
test = getattr(self.test_dataset, property_name, None)
filtered = list(filter(lambda x: x is not None, [train, val, test]))
return filtered[0] if len(filtered) > 0 else None

@property
def num_classes(self) -> Optional[int]:
"""Property that returns the number of classes of the datamodule if a multiclass task."""
n_cls_train = getattr(self.train_dataset, "num_classes", None)
n_cls_val = getattr(self.val_dataset, "num_classes", None)
n_cls_test = getattr(self.test_dataset, "num_classes", None)
return n_cls_train or n_cls_val or n_cls_test
return self._get_property("num_classes")

@property
def labels(self) -> Optional[int]:
"""Property that returns the labels if this ``DataModule`` contains classification data."""
n_cls_train = getattr(self.train_dataset, "labels", None)
n_cls_val = getattr(self.val_dataset, "labels", None)
n_cls_test = getattr(self.test_dataset, "labels", None)
return n_cls_train or n_cls_val or n_cls_test
return self._get_property("labels")

@property
def multi_label(self) -> Optional[bool]:
"""Property that returns ``True`` if this ``DataModule`` contains multi-label data."""
multi_label_train = getattr(self.train_dataset, "multi_label", None)
multi_label_val = getattr(self.val_dataset, "multi_label", None)
multi_label_test = getattr(self.test_dataset, "multi_label", None)
return multi_label_train or multi_label_val or multi_label_test
return self._get_property("multi_label")

@property
def inputs(self) -> Optional[Union[Input, List[InputBase]]]:
Expand Down
2 changes: 1 addition & 1 deletion flash/core/data/utilities/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def _get_target_formatter_type(target: Any) -> Type[TargetFormatter]:
MultiBinaryTargetFormatter: [MultiNumericTargetFormatter],
SingleBinaryTargetFormatter: [MultiBinaryTargetFormatter, MultiNumericTargetFormatter],
SingleLabelTargetFormatter: [CommaDelimitedMultiLabelTargetFormatter, SpaceDelimitedTargetFormatter],
SingleNumericTargetFormatter: [MultiNumericTargetFormatter],
SingleNumericTargetFormatter: [SingleBinaryTargetFormatter, MultiNumericTargetFormatter],
}


Expand Down
2 changes: 1 addition & 1 deletion tests/core/data/utilities/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
3,
),
# Ambiguous
Case([[0], [1, 2], [2, 0]], [[1, 0, 0], [0, 1, 1], [1, 0, 1]], MultiNumericTargetFormatter, None, 3),
Case([[0], [0, 1], [1, 2]], [[1, 0, 0], [1, 1, 0], [0, 1, 1]], MultiNumericTargetFormatter, None, 3),
Case([[1, 0, 0], [0, 1, 1], [1, 0, 1]], [[1, 0, 0], [0, 1, 1], [1, 0, 1]], MultiBinaryTargetFormatter, None, 3),
Case(
[["blue"], ["green", "red"], ["red", "blue"]],
Expand Down