Skip to content

Commit

Permalink
InriaAerialImageLabelingDataModule: fix predict dimensions (microsoft…
Browse files Browse the repository at this point in the history
…#975)

* InriaAerialImageLabelingDataModule: fix predict dimensions

* Record number of patches for reconstruction
  • Loading branch information
adamjstewart authored Dec 26, 2022
1 parent e3741ec commit 321387e
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 70 deletions.
File renamed without changes.
20 changes: 20 additions & 0 deletions tests/conf/inria_train.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
experiment:
task: "inria"
module:
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: "imagenet"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 3
num_classes: 2
ignore_index: null
datamodule:
root: "tests/data/inria"
batch_size: 1
num_workers: 0
val_split_pct: 0.0
test_split_pct: 0.0
patch_size: 2
num_patches_per_tile: 2
20 changes: 20 additions & 0 deletions tests/conf/inria_val.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
experiment:
task: "inria"
module:
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: "imagenet"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 3
num_classes: 2
ignore_index: null
datamodule:
root: "tests/data/inria"
batch_size: 1
num_workers: 0
val_split_pct: 0.2
test_split_pct: 0.0
patch_size: 2
num_patches_per_tile: 2
67 changes: 0 additions & 67 deletions tests/datamodules/test_inria.py

This file was deleted.

8 changes: 6 additions & 2 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ class TestSemanticSegmentationTask:
("deepglobelandcover_0", DeepGlobeLandCoverDataModule),
("deepglobelandcover_5", DeepGlobeLandCoverDataModule),
("etci2021", ETCI2021DataModule),
("inria", InriaAerialImageLabelingDataModule),
("inria_train", InriaAerialImageLabelingDataModule),
("inria_val", InriaAerialImageLabelingDataModule),
("inria_test", InriaAerialImageLabelingDataModule),
("landcoverai", LandCoverAIDataModule),
("naipchesapeake", NAIPChesapeakeDataModule),
("oscd_all", OSCDDataModule),
Expand Down Expand Up @@ -80,7 +82,9 @@ def test_trainer(
trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1)
trainer.fit(model=model, datamodule=datamodule)
trainer.test(model=model, datamodule=datamodule)
trainer.predict(model=model, dataloaders=datamodule.val_dataloader())

if hasattr(datamodule, "predict_dataset"):
trainer.predict(model=model, datamodule=datamodule)

def test_no_logger(self) -> None:
conf = OmegaConf.load(os.path.join("tests", "conf", "landcoverai.yaml"))
Expand Down
4 changes: 3 additions & 1 deletion torchgeo/datamodules/inria.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,9 @@ def patch_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
self.patch_size,
padding=padding,
)
sample["image"] = rearrange(sample["image"], "() t c h w -> t () c h w")
# Needed for reconstruction of patches later
sample["num_patches"] = sample["image"].shape[1]
sample["image"] = rearrange(sample["image"], "b n c h w -> (b n) c h w")
return sample

def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
Expand Down

0 comments on commit 321387e

Please sign in to comment.