Add VHR10 datamodule #1082

Merged · 33 commits · Jan 25, 2024

Commits:
3c3a306  Add VHR10 datamodule (ashnair1, Feb 2, 2023)
41cf32d  Add newline (ashnair1, Feb 2, 2023)
1efc331  patch_size accepts int and tuple of ints (ashnair1, Feb 2, 2023)
74c4bd5  Update conf (ashnair1, Feb 2, 2023)
b4511ea  VHR10 Datamodule v2 (ashnair1, Mar 30, 2023)
f273f03  Remove auto_lr_find (ashnair1, Apr 10, 2023)
f4a96a8  Remove preprocess (ashnair1, Apr 11, 2023)
fec3449  Update config (ashnair1, May 4, 2023)
738d996  Remove setting of matplotlib backend (ashnair1, May 4, 2023)
97bb234  Remove import (ashnair1, May 10, 2023)
54a3116  Typing update (ashnair1, May 10, 2023)
5bdb57a  Key fix (ashnair1, Jun 29, 2023)
bb4fd46  Coverage fix (ashnair1, Jun 29, 2023)
f31c82f  Update conf (ashnair1, Oct 24, 2023)
bf69bac  Update conf (ashnair1, Oct 24, 2023)
4cd4568  Dowload=True (ashnair1, Oct 24, 2023)
9d48f1f  Use weights (ashnair1, Oct 25, 2023)
2bdbf0f  Empty commit (ashnair1, Oct 30, 2023)
3efa9be  Switch to ndim (ashnair1, Nov 6, 2023)
e3417e8  Remove conf, tight_layout and spacing (ashnair1, Nov 9, 2023)
d4700cc  Set constrained layout via rcParams (ashnair1, Nov 9, 2023)
756143f  Revert and bump min matplotlib version (ashnair1, Nov 9, 2023)
f750cf4  Switch back to dataset_split (ashnair1, Nov 10, 2023)
b8f0166  Separate out AugPipe (ashnair1, Nov 15, 2023)
0f13514  Increase figsize & revert matplotlib (ashnair1, Nov 21, 2023)
d976e5f  Common collate_fn (ashnair1, Nov 21, 2023)
77d3feb  Class var std (ashnair1, Nov 21, 2023)
9f21d06  Undo std change in BaseDataModule (ashnair1, Nov 23, 2023)
61a1d46  Undo req changes (ashnair1, Nov 30, 2023)
7e88d02  Remove unused line (ashnair1, Dec 6, 2023)
30b02d6  Add version strings (ashnair1, Dec 12, 2023)
1f74d9b  mypy fix (ashnair1, Jan 9, 2024)
1961950  Merge branch 'main' into vhr10-datamodule (adamjstewart, Jan 25, 2024)
Files changed:
17 changes: 17 additions & 0 deletions tests/conf/vhr10.yaml
@@ -0,0 +1,17 @@
model:
class_path: ObjectDetectionTask
init_args:
model: "faster-rcnn"
backbone: "resnet50"
num_classes: 11
lr: 2.5e-5
patience: 10
data:
class_path: VHR10DataModule
init_args:
batch_size: 1
num_workers: 0
patch_size: 4
dict_kwargs:
root: "tests/data/vhr10"
download: true
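
For context, a minimal sketch of what this test config wires together when instantiated by hand rather than through the CLI. The class names and arguments come straight from the YAML above; the lightning import and Trainer flags are assumptions about the surrounding test harness, not part of this PR:

import lightning.pytorch as pl

from torchgeo.datamodules import VHR10DataModule
from torchgeo.trainers import ObjectDetectionTask

# Mirrors model.init_args and data.init_args / dict_kwargs from the config.
task = ObjectDetectionTask(
    model="faster-rcnn", backbone="resnet50", num_classes=11, lr=2.5e-5, patience=10
)
datamodule = VHR10DataModule(
    batch_size=1, num_workers=0, patch_size=4, root="tests/data/vhr10", download=True
)
trainer = pl.Trainer(fast_dev_run=True)  # illustrative; runs a single batch
trainer.fit(model=task, datamodule=datamodule)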
Binary file modified tests/data/vhr10/NWPU VHR-10 dataset.rar
2 changes: 1 addition & 1 deletion tests/data/vhr10/annotations.json
@@ -1 +1 @@
{"images": [{"file_name": "001.jpg", "height": 8, "width": 8, "id": 0}, {"file_name": "002.jpg", "height": 8, "width": 8, "id": 1}, {"file_name": "003.jpg", "height": 8, "width": 8, "id": 2}, {"file_name": "004.jpg", "height": 8, "width": 8, "id": 3}, {"file_name": "005.jpg", "height": 8, "width": 8, "id": 4}], "annotations": [{"id": 0, "image_id": 0, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "iscrowd": 0}, {"id": 1, "image_id": 1, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 2, "image_id": 2, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 3, "image_id": 3, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 4, "image_id": 4, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}]}
{"images": [{"file_name": "001.jpg", "height": 8, "width": 8, "id": 0}, {"file_name": "002.jpg", "height": 8, "width": 8, "id": 1}, {"file_name": "003.jpg", "height": 8, "width": 8, "id": 2}, {"file_name": "004.jpg", "height": 8, "width": 8, "id": 3}, {"file_name": "005.jpg", "height": 8, "width": 8, "id": 4}], "annotations": [{"id": 0, "image_id": 0, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 1, "image_id": 1, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 2, "image_id": 2, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 3, "image_id": 3, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 4, "image_id": 4, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}]}
10 changes: 2 additions & 8 deletions tests/data/vhr10/data.py
@@ -5,7 +5,6 @@
import os
import shutil
import subprocess
from copy import deepcopy

import numpy as np
from PIL import Image
@@ -47,7 +46,7 @@ def generate_test_data(root: str, n_imgs: int = 3) -> str:
)

ann = 0
for i, img in enumerate(ANNOTATION_FILE["images"]):
for _, img in enumerate(ANNOTATION_FILE["images"]):
annot = {
"id": ann,
"image_id": img["id"],
@@ -57,12 +56,7 @@ def generate_test_data(root: str, n_imgs: int = 3) -> str:
"segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]],
"iscrowd": 0,
}
if i != 0:
ANNOTATION_FILE["annotations"].append(annot)
else:
noseg_annot = deepcopy(annot)
del noseg_annot["segmentation"]
ANNOTATION_FILE["annotations"].append(noseg_annot)
ANNOTATION_FILE["annotations"].append(annot)
ann += 1

with open(ann_file, "w") as j:
4 changes: 2 additions & 2 deletions tests/datasets/test_vhr10.py
@@ -35,11 +35,11 @@ def dataset(
monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url)
url = os.path.join("tests", "data", "vhr10", "NWPU VHR-10 dataset.rar")
monkeypatch.setitem(VHR10.image_meta, "url", url)
md5 = "5fddb0dfd56a80638831df9f90cbf37a"
md5 = "92769845cae6a4e8c74bfa1a0d1d4a80"
monkeypatch.setitem(VHR10.image_meta, "md5", md5)
url = os.path.join("tests", "data", "vhr10", "annotations.json")
monkeypatch.setitem(VHR10.target_meta, "url", url)
md5 = "833899cce369168e0d4ee420dac326dc"
md5 = "567c4cd8c12624864ff04865de504c58"
monkeypatch.setitem(VHR10.target_meta, "md5", md5)
root = str(tmp_path)
split = request.param
2 changes: 1 addition & 1 deletion tests/trainers/test_detection.py
@@ -67,7 +67,7 @@ def plot(*args: Any, **kwargs: Any) -> None:


class TestObjectDetectionTask:
@pytest.mark.parametrize("name", ["nasa_marine_debris"])
@pytest.mark.parametrize("name", ["nasa_marine_debris", "vhr10"])
@pytest.mark.parametrize("model_name", ["faster-rcnn", "fcos", "retinanet"])
def test_trainer(
self, monkeypatch: MonkeyPatch, name: str, model_name: str, fast_dev_run: bool
16 changes: 8 additions & 8 deletions tests/transforms/test_transforms.py
@@ -23,7 +23,7 @@ def batch_gray() -> dict[str, Tensor]:
return {
"image": torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float),
"mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long),
"boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float),
"boxes": torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
"labels": torch.tensor([[0, 1]]),
}

@@ -42,7 +42,7 @@ def batch_rgb() -> dict[str, Tensor]:
dtype=torch.float,
),
"mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long),
"boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float),
"boxes": torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
"labels": torch.tensor([[0, 1]]),
}

@@ -63,7 +63,7 @@ def batch_multispectral() -> dict[str, Tensor]:
dtype=torch.float,
),
"mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long),
"boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float),
"boxes": torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
"labels": torch.tensor([[0, 1]]),
}

@@ -79,7 +79,7 @@ def test_augmentation_sequential_gray(batch_gray: dict[str, Tensor]) -> None:
expected = {
"image": torch.tensor([[[[3, 2, 1], [6, 5, 4], [9, 8, 7]]]], dtype=torch.float),
"mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long),
"boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float),
"boxes": torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float),
"labels": torch.tensor([[0, 1]]),
}
augs = transforms.AugmentationSequential(
@@ -102,7 +102,7 @@ def test_augmentation_sequential_rgb(batch_rgb: dict[str, Tensor]) -> None:
dtype=torch.float,
),
"mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long),
"boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float),
"boxes": torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float),
"labels": torch.tensor([[0, 1]]),
}
augs = transforms.AugmentationSequential(
@@ -129,7 +129,7 @@ def test_augmentation_sequential_multispectral(
dtype=torch.float,
),
"mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long),
"boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float),
"boxes": torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float),
"labels": torch.tensor([[0, 1]]),
}
augs = transforms.AugmentationSequential(
@@ -156,7 +156,7 @@ def test_augmentation_sequential_image_only(
dtype=torch.float,
),
"mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long),
"boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float),
"boxes": torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
"labels": torch.tensor([[0, 1]]),
}
augs = transforms.AugmentationSequential(
@@ -188,7 +188,7 @@ def test_sequential_transforms_augmentations(
dtype=torch.float,
),
"mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long),
"boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float),
"boxes": torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float),
"labels": torch.tensor([[0, 1]]),
}
train_transforms = transforms.AugmentationSequential(
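
The fixture updates above replace Kornia's (N, 4, 2) corner representation of bounding boxes with flat (N, 4) xyxy tensors, the format the new detection collate function passes through. For anyone migrating similar fixtures, a sketch of the corner-to-xyxy conversion, assuming axis-aligned boxes; the helper is illustrative and not part of this PR (the expected values above were re-derived by hand, not mechanically converted):

import torch


def corners_to_xyxy(corners: torch.Tensor) -> torch.Tensor:
    """Collapse (N, 4, 2) corner boxes to (N, 4) xyxy, assuming axis alignment."""
    x_min = corners[..., 0].amin(dim=-1)
    y_min = corners[..., 1].amin(dim=-1)
    x_max = corners[..., 0].amax(dim=-1)
    y_max = corners[..., 1].amax(dim=-1)
    return torch.stack([x_min, y_min, x_max, y_max], dim=-1)


corners = torch.tensor([[[0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]])
print(corners_to_xyxy(corners))  # tensor([[0., 0., 1., 1.]])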
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
@@ -38,6 +38,7 @@
from .usavars import USAVarsDataModule
from .utils import MisconfigurationException
from .vaihingen import Vaihingen2DDataModule
from .vhr10 import VHR10DataModule
from .xview import XView2DataModule

__all__ = (
@@ -79,6 +80,7 @@
"UCMercedDataModule",
"USAVarsDataModule",
"Vaihingen2DDataModule",
"VHR10DataModule",
"XView2DataModule",
# Base classes
"BaseDataModule",
32 changes: 13 additions & 19 deletions torchgeo/datamodules/nasa_marine_debris.py
@@ -5,28 +5,13 @@

from typing import Any

import kornia.augmentation as K
import torch
from torch import Tensor

from ..datasets import NASAMarineDebris
from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule
from .utils import dataset_split


def collate_fn(batch: list[dict[str, Tensor]]) -> dict[str, Any]:
"""Custom object detection collate fn to handle variable boxes.

Args:
batch: list of sample dicts return by dataset

Returns:
batch dict output
"""
output: dict[str, Any] = {}
output["image"] = torch.stack([sample["image"] for sample in batch])
output["boxes"] = [sample["boxes"] for sample in batch]
output["labels"] = [torch.tensor([1] * len(sample["boxes"])) for sample in batch]
return output
from .utils import AugPipe, collate_fn_detection, dataset_split


class NASAMarineDebrisDataModule(NonGeoDataModule):
@@ -35,6 +20,8 @@ class NASAMarineDebrisDataModule(NonGeoDataModule):
.. versionadded:: 0.2
"""

std = torch.tensor(255)

def __init__(
self,
batch_size: int = 64,
@@ -58,7 +45,14 @@ def __init__(
self.val_split_pct = val_split_pct
self.test_split_pct = test_split_pct

self.collate_fn = collate_fn
self.aug = AugPipe(
AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "boxes"]
),
batch_size,
)

self.collate_fn = collate_fn_detection

def setup(self, stage: str) -> None:
"""Set up datasets.
87 changes: 85 additions & 2 deletions torchgeo/datamodules/utils.py
@@ -5,10 +5,13 @@

import math
from collections.abc import Iterable
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union

import numpy as np
from torch import Generator
import torch
from einops import rearrange
from torch import Generator, Tensor
from torch.nn import Module
from torch.utils.data import Subset, TensorDataset, random_split

from ..datasets import NonGeoDataset
@@ -19,6 +22,86 @@ class MisconfigurationException(Exception):
"""Exception used to inform users of misuse with Lightning."""


class AugPipe(Module):
"""Pipeline for applying augmentations sequentially on select data keys.

.. versionadded:: 0.6
"""

def __init__(
self, augs: Callable[[dict[str, Any]], dict[str, Any]], batch_size: int
) -> None:
"""Initialize a new AugPipe instance.

Args:
augs: Augmentations to apply.
batch_size: Batch size.
"""
super().__init__()
self.augs = augs
self.batch_size = batch_size

def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Apply the augmentation.

Args:
batch: Input batch.

Returns:
Augmented batch.
"""
batch_len = len(batch["image"])
for bs in range(batch_len):
batch_dict = {
"image": batch["image"][bs],
"labels": batch["labels"][bs],
"boxes": batch["boxes"][bs],
}

if "masks" in batch:
batch_dict["masks"] = batch["masks"][bs]

batch_dict = self.augs(batch_dict)

batch["image"][bs] = batch_dict["image"]
batch["labels"][bs] = batch_dict["labels"]
batch["boxes"][bs] = batch_dict["boxes"]

if "masks" in batch:
batch["masks"][bs] = batch_dict["masks"]

# Stack images
batch["image"] = rearrange(batch["image"], "b () c h w -> b c h w")

return batch
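
Illustrative usage, not part of the diff: AugPipe exists because detection batches carry variable-length boxes/labels/masks lists that Kornia cannot augment as one stacked batch, so the pipeline loops over samples. A sketch mirroring how the NASAMarineDebris datamodule above wires it up; the normalization constants are assumptions:

import kornia.augmentation as K
import torch

from torchgeo.transforms import AugmentationSequential

aug = AugPipe(
    AugmentationSequential(
        K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
        data_keys=["image", "boxes"],
    ),
    batch_size=2,
)
# A datamodule assigns this to self.aug; Lightning then applies it to each
# collated batch, i.e. batch = aug(batch).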


def collate_fn_detection(batch: list[dict[str, Tensor]]) -> dict[str, Any]:
"""Custom collate fn for object detection and instance segmentation.

Args:
batch: list of sample dicts returned by the dataset

Returns:
batch dict output

.. versionadded:: 0.6
"""
output: dict[str, Any] = {}

Inline review thread on this function:

adamjstewart (Collaborator): Could we instead use torchgeo.datasets.unbind_samples and add one extra line to modify labels?

ashnair1 (Author): unbind_samples is used to convert a dict of lists into a list of dicts. Here we're doing the opposite.

adamjstewart (Collaborator): Then what about concat/merge_samples? Those do the opposite.

ashnair1 (Author): Do you mean stack_samples? I think concat and merge samples were for GeoDatasets. Either way, those functions can't be used, since they assume the samples can be stacked or concatenated, which is not the case here.

adamjstewart (Collaborator, Jan 15, 2024): Both GeoDatasets and NonGeoDatasets use the same sample dict format, so all collation functions are appropriate for both, although they tend to be used more with GeoDatasets because the output also contains CRS and BoundingBox, which the default collation functions from PyTorch don't support. I see now that your output dict has lists of Tensors instead of Tensors, which is indeed different from all other builtin collation functions. I guess this is required because each sample may have a different number of objects. Not sure how people normally handle collation for that case in PyTorch.

output["image"] = [sample["image"] for sample in batch]
output["boxes"] = [sample["boxes"].float() for sample in batch]
if "labels" in batch[0]:
output["labels"] = [sample["labels"] for sample in batch]
else:
output["labels"] = [
torch.tensor([1] * len(sample["boxes"])) for sample in batch
]

if "masks" in batch[0]:
output["masks"] = [sample["masks"] for sample in batch]
return output
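
Illustrative only, not part of the diff: because every field stays a per-sample list, samples with different numbers of objects collate cleanly, which stacking collate functions cannot handle.

example = [
    {"image": torch.rand(3, 8, 8), "boxes": torch.tensor([[0, 0, 2, 2]])},
    {
        "image": torch.rand(3, 8, 8),
        "boxes": torch.tensor([[0, 0, 2, 2], [1, 1, 3, 3]]),
    },
]
collated = collate_fn_detection(example)
assert len(collated["boxes"]) == 2
assert collated["boxes"][1].shape == (2, 4)
assert [len(t) for t in collated["labels"]] == [1, 2]  # default label 1 per box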


def dataset_split(
dataset: Union[TensorDataset, NonGeoDataset],
val_pct: float,