Skip to content

Commit

Permalink
EuroSATDataModule: set mean/std based on bands (#1681)
Browse files Browse the repository at this point in the history
* Use dicts to generate mean and std

* correctly pass bands

* black format

* Remove unused import

* Simplify

* Import all bands
  • Loading branch information
robmarkcole authored and nilsleh committed Nov 6, 2023
1 parent 1b48cb0 commit a6a11fd
Showing 1 changed file with 39 additions and 41 deletions.
80 changes: 39 additions & 41 deletions torchgeo/datamodules/eurosat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,41 +10,37 @@
from ..datasets import EuroSAT, EuroSAT100
from .geo import NonGeoDataModule

MEAN = torch.tensor(
[
1354.40546513,
1118.24399958,
1042.92983953,
947.62620298,
1199.47283961,
1999.79090914,
2369.22292565,
2296.82608323,
732.08340178,
12.11327804,
1819.01027855,
1118.92391149,
2594.14080798,
]
)

STD = torch.tensor(
[
245.71762908,
333.00778264,
395.09249139,
593.75055589,
566.4170017,
861.18399006,
1086.63139075,
1117.98170791,
404.91978886,
4.77584468,
1002.58768311,
761.30323499,
1231.58581042,
]
)
MEAN = {
"B01": 1354.40546513,
"B02": 1118.24399958,
"B03": 1042.92983953,
"B04": 947.62620298,
"B05": 1199.47283961,
"B06": 1999.79090914,
"B07": 2369.22292565,
"B08": 2296.82608323,
"B8A": 732.08340178,
"B09": 12.11327804,
"B10": 1819.01027855,
"B11": 1118.92391149,
"B12": 2594.14080798,
}

STD = {
"B01": 245.71762908,
"B02": 333.00778264,
"B03": 395.09249139,
"B04": 593.75055589,
"B05": 566.4170017,
"B06": 861.18399006,
"B07": 1086.63139075,
"B08": 1117.98170791,
"B8A": 404.91978886,
"B09": 4.77584468,
"B10": 1002.58768311,
"B11": 761.30323499,
"B12": 1231.58581042,
}


class EuroSATDataModule(NonGeoDataModule):
Expand All @@ -55,9 +51,6 @@ class EuroSATDataModule(NonGeoDataModule):
.. versionadded:: 0.2
"""

mean = MEAN
std = STD

def __init__(
self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
) -> None:
Expand All @@ -71,6 +64,10 @@ def __init__(
"""
super().__init__(EuroSAT, batch_size, num_workers, **kwargs)

bands = kwargs.get("bands", EuroSAT.all_band_names)
self.mean = torch.tensor([MEAN[b] for b in bands])
self.std = torch.tensor([STD[b] for b in bands])


class EuroSAT100DataModule(NonGeoDataModule):
"""LightningDataModule implementation for the EuroSAT100 dataset.
Expand All @@ -80,9 +77,6 @@ class EuroSAT100DataModule(NonGeoDataModule):
.. versionadded:: 0.5
"""

mean = MEAN
std = STD

def __init__(
self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
) -> None:
Expand All @@ -95,3 +89,7 @@ def __init__(
:class:`~torchgeo.datasets.EuroSAT100`.
"""
super().__init__(EuroSAT100, batch_size, num_workers, **kwargs)

bands = kwargs.get("bands", EuroSAT.all_band_names)
self.mean = torch.tensor([MEAN[b] for b in bands])
self.std = torch.tensor([STD[b] for b in bands])

0 comments on commit a6a11fd

Please sign in to comment.