Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
update doc object detection (#1110)
Browse files Browse the repository at this point in the history
Co-authored-by: Luca Actis Grosso <[email protected]>
Co-authored-by: Ethan Harris <[email protected]>
  • Loading branch information
3 people authored Jan 12, 2022
1 parent 1739ada commit 127b7c0
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions docs/source/reference/object_detection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ Here's an outline:
...
Once we've downloaded the data using :func:`~flash.core.data.download_data`, we can create the :class:`~flash.image.detection.data.ObjectDetectionData`.
We select a pre-trained RetinaNet to use for our :class:`~flash.image.detection.model.ObjectDetector` and fine-tune on the COCO 128 data.
We select a pre-trained EfficientDet to use for our :class:`~flash.image.detection.model.ObjectDetector` and fine-tune on the COCO 128 data.
We then use the trained :class:`~flash.image.detection.model.ObjectDetector` for inference.
Finally, we save the model.
Here's the full example:
Expand Down Expand Up @@ -82,26 +82,34 @@ Custom Transformations

Flash automatically applies some default image / mask transformations and augmentations, but you may wish to customize these for your own use case.
The base :class:`~flash.core.data.io.input_transform.InputTransform` defines 7 hooks for different stages in the data loading pipeline.
For object-detection tasks, you can leverage the transformations from `Albumentations <https://github.com/albumentations-team/albumentations>`__ with the :class:`~flash.core.integrations.icevision.transforms.IceVisionTransformAdapter`.
For object-detection tasks, you can leverage the transformations from `Albumentations <https://github.com/albumentations-team/albumentations>`__ with the :class:`~flash.core.integrations.icevision.transforms.IceVisionTransformAdapter`,
creating a subclass of :class:`~flash.core.data.io.input_transform.InputTransform`

.. code-block:: python
from dataclasses import dataclass
import albumentations as alb
from icevision.tfms import A
from flash import InputTransform
from flash.core.integrations.icevision.transforms import IceVisionTransformAdapter
from flash.image import ObjectDetectionData
train_transform = {
"per_sample_transform": transforms.IceVisionTransformAdapter(
[*A.resize_and_pad(128), A.Normalize(), A.Flip(0.4), alb.RandomBrightnessContrast()]
)
}
@dataclass
class BrightnessContrastTransform(InputTransform):
image_size: int = 128
def per_sample_transform(self):
return IceVisionTransformAdapter(
[*A.aug_tfms(size=self.image_size), A.Normalize(), alb.RandomBrightnessContrast()]
)
datamodule = ObjectDetectionData.from_coco(
train_folder="data/coco128/images/train2017/",
train_ann_file="data/coco128/annotations/instances_train2017.json",
val_split=0.1,
image_size=128,
train_transform=train_transform,
train_transform=BrightnessContrastTransform,
batch_size=4,
)

0 comments on commit 127b7c0

Please sign in to comment.