From a6a11fd5b89cbd40f707815f20777aeb2e75fdfc Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Fri, 20 Oct 2023 10:40:40 +0100 Subject: [PATCH] EuroSATDataModule: set mean/std based on bands (#1681) * Use dicts to generate mean and std * correctly pass bands * black format * Remove unused import * Simplify * Import all bands --- torchgeo/datamodules/eurosat.py | 80 ++++++++++++++++----------------- 1 file changed, 39 insertions(+), 41 deletions(-) diff --git a/torchgeo/datamodules/eurosat.py b/torchgeo/datamodules/eurosat.py index b4267cfc50a..ccf2a90f691 100644 --- a/torchgeo/datamodules/eurosat.py +++ b/torchgeo/datamodules/eurosat.py @@ -10,41 +10,37 @@ from ..datasets import EuroSAT, EuroSAT100 from .geo import NonGeoDataModule -MEAN = torch.tensor( - [ - 1354.40546513, - 1118.24399958, - 1042.92983953, - 947.62620298, - 1199.47283961, - 1999.79090914, - 2369.22292565, - 2296.82608323, - 732.08340178, - 12.11327804, - 1819.01027855, - 1118.92391149, - 2594.14080798, - ] -) - -STD = torch.tensor( - [ - 245.71762908, - 333.00778264, - 395.09249139, - 593.75055589, - 566.4170017, - 861.18399006, - 1086.63139075, - 1117.98170791, - 404.91978886, - 4.77584468, - 1002.58768311, - 761.30323499, - 1231.58581042, - ] -) +MEAN = { + "B01": 1354.40546513, + "B02": 1118.24399958, + "B03": 1042.92983953, + "B04": 947.62620298, + "B05": 1199.47283961, + "B06": 1999.79090914, + "B07": 2369.22292565, + "B08": 2296.82608323, + "B8A": 732.08340178, + "B09": 12.11327804, + "B10": 1819.01027855, + "B11": 1118.92391149, + "B12": 2594.14080798, +} + +STD = { + "B01": 245.71762908, + "B02": 333.00778264, + "B03": 395.09249139, + "B04": 593.75055589, + "B05": 566.4170017, + "B06": 861.18399006, + "B07": 1086.63139075, + "B08": 1117.98170791, + "B8A": 404.91978886, + "B09": 4.77584468, + "B10": 1002.58768311, + "B11": 761.30323499, + "B12": 1231.58581042, +} class EuroSATDataModule(NonGeoDataModule): @@ -55,9 +51,6 @@ class EuroSATDataModule(NonGeoDataModule): .. versionadded:: 0.2 """ - mean = MEAN - std = STD - def __init__( self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: @@ -71,6 +64,10 @@ def __init__( """ super().__init__(EuroSAT, batch_size, num_workers, **kwargs) + bands = kwargs.get("bands", EuroSAT.all_band_names) + self.mean = torch.tensor([MEAN[b] for b in bands]) + self.std = torch.tensor([STD[b] for b in bands]) + class EuroSAT100DataModule(NonGeoDataModule): """LightningDataModule implementation for the EuroSAT100 dataset. @@ -80,9 +77,6 @@ class EuroSAT100DataModule(NonGeoDataModule): .. versionadded:: 0.5 """ - mean = MEAN - std = STD - def __init__( self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: @@ -95,3 +89,7 @@ def __init__( :class:`~torchgeo.datasets.EuroSAT100`. """ super().__init__(EuroSAT100, batch_size, num_workers, **kwargs) + + bands = kwargs.get("bands", EuroSAT.all_band_names) + self.mean = torch.tensor([MEAN[b] for b in bands]) + self.std = torch.tensor([STD[b] for b in bands])