forked from microsoft/torchgeo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
USAVars: implementing DataModule (microsoft#441)
* USAVars: implementing DataModule * Adding initial version * add to __init__ * changes * add transforms argument * black, isort fix * fixed shuffle option * update docs * fix formatting * initial split method * fix formatting * testing for datamodule * this is simpler * testing seed * fix isort + test seed * refactor dataset for splits * black fix * adding splits to fake data * change test splits * working tests locally * fix black * fix black * adapt module to dataset refactor * complete docstring * Style fixes Co-authored-by: Adam J. Stewart <[email protected]>
- Loading branch information
1 parent
2c1477a
commit 8d89b70
Showing
10 changed files
with
241 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
tile_0,0.tif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
tile_0,0.tif | ||
tile_0,0.tif | ||
tile_0,0.tif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
tile_0,1.tif | ||
tile_0,1.tif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import os | ||
|
||
import pytest | ||
from _pytest.fixtures import SubRequest | ||
|
||
from torchgeo.datamodules import USAVarsDataModule | ||
|
||
|
||
class TestUSAVarsDataModule: | ||
@pytest.fixture() | ||
def datamodule(self, request: SubRequest) -> USAVarsDataModule: | ||
root = os.path.join("tests", "data", "usavars") | ||
batch_size = 1 | ||
num_workers = 0 | ||
|
||
dm = USAVarsDataModule(root, batch_size=batch_size, num_workers=num_workers) | ||
dm.prepare_data() | ||
dm.setup() | ||
return dm | ||
|
||
def test_train_dataloader(self, datamodule: USAVarsDataModule) -> None: | ||
assert len(datamodule.train_dataloader()) == 3 | ||
sample = next(iter(datamodule.train_dataloader())) | ||
assert sample["image"].shape[0] == datamodule.batch_size | ||
|
||
def test_val_dataloader(self, datamodule: USAVarsDataModule) -> None: | ||
assert len(datamodule.val_dataloader()) == 2 | ||
sample = next(iter(datamodule.val_dataloader())) | ||
assert sample["image"].shape[0] == datamodule.batch_size | ||
|
||
def test_test_dataloader(self, datamodule: USAVarsDataModule) -> None: | ||
assert len(datamodule.test_dataloader()) == 1 | ||
sample = next(iter(datamodule.test_dataloader())) | ||
assert sample["image"].shape[0] == datamodule.batch_size |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
"""USAVars datamodule.""" | ||
|
||
from typing import Any, Callable, Dict, Optional, Sequence | ||
|
||
import pytorch_lightning as pl | ||
from torch import Tensor | ||
from torch.utils.data import DataLoader | ||
|
||
from ..datasets import USAVars | ||
|
||
|
||
class USAVarsDataModule(pl.LightningModule): | ||
"""LightningDataModule implementation for the USAVars dataset. | ||
Uses random train/val/test splits. | ||
.. versionadded:: 0.3 | ||
""" | ||
|
||
def __init__( | ||
self, | ||
root_dir: str, | ||
labels: Sequence[str] = USAVars.ALL_LABELS, | ||
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, | ||
batch_size: int = 64, | ||
num_workers: int = 0, | ||
) -> None: | ||
"""Initialize a LightningDataModule for USAVars based DataLoaders. | ||
Args: | ||
root_dir: The root argument passed to the USAVars Dataset classes | ||
labels: The labels argument passed to the USAVars Dataset classes | ||
transforms: a function/transform that takes input sample and its target as | ||
entry and returns a transformed version | ||
batch_size: The batch size to use in all created DataLoaders | ||
num_workers: The number of workers to use in all created DataLoaders | ||
""" | ||
super().__init__() | ||
self.root_dir = root_dir | ||
self.labels = labels | ||
self.transforms = transforms | ||
self.batch_size = batch_size | ||
self.num_workers = num_workers | ||
|
||
def prepare_data(self) -> None: | ||
"""Make sure that the dataset is downloaded. | ||
This method is only called once per run. | ||
""" | ||
USAVars(self.root_dir, labels=self.labels, checksum=False) | ||
|
||
def setup(self, stage: Optional[str] = None) -> None: | ||
"""Initialize the main Dataset objects. | ||
This method is called once per GPU per run. | ||
""" | ||
self.train_dataset = USAVars( | ||
self.root_dir, "train", self.labels, transforms=self.transforms | ||
) | ||
self.val_dataset = USAVars( | ||
self.root_dir, "val", self.labels, transforms=self.transforms | ||
) | ||
self.test_dataset = USAVars( | ||
self.root_dir, "test", self.labels, transforms=self.transforms | ||
) | ||
|
||
def train_dataloader(self) -> DataLoader[Any]: | ||
"""Return a DataLoader for training.""" | ||
return DataLoader( | ||
self.train_dataset, | ||
batch_size=self.batch_size, | ||
num_workers=self.num_workers, | ||
shuffle=False, | ||
) | ||
|
||
def val_dataloader(self) -> DataLoader[Any]: | ||
"""Return a DataLoader for validation.""" | ||
return DataLoader( | ||
self.val_dataset, | ||
batch_size=self.batch_size, | ||
num_workers=self.num_workers, | ||
shuffle=False, | ||
) | ||
|
||
def test_dataloader(self) -> DataLoader[Any]: | ||
"""Return a DataLoader for testing.""" | ||
return DataLoader( | ||
self.test_dataset, | ||
batch_size=self.batch_size, | ||
num_workers=self.num_workers, | ||
shuffle=False, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters