Skip to content

Commit

Permalink
refactor: ruff changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Mathias Baumgartinger committed Dec 7, 2024
1 parent 61e2f2f commit f49b84f
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 7 deletions.
4 changes: 3 additions & 1 deletion tests/data/flair2/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,9 @@ def create_metadata(root_dir: str) -> None:
if os.path.exists(root_dir):
shutil.rmtree(root_dir)

def create_dir_structure(root_dir: str, dir_names: dict[str, dict[str, str]]) -> None:
def create_dir_structure(
root_dir: str, dir_names: dict[str, dict[str, str]]
) -> None:
# Create the directory structure
for split in splits:
for i in range(DUMMY_DATA_SIZE[split]):
Expand Down
40 changes: 34 additions & 6 deletions torchgeo/datasets/flair2.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,28 @@ class FLAIR2(NonGeoDataset):

# Band information
aerial_rgb_bands: tuple[str, str, str] = ('B01', 'B02', 'B03')
aerial_all_bands: tuple[str, str, str, str, str] = ('B01', 'B02', 'B03', 'B04', 'B05')
aerial_all_bands: tuple[str, str, str, str, str] = (
'B01',
'B02',
'B03',
'B04',
'B05',
)
sentinel_rgb_bands: tuple[str, str, str] = ('B03', 'B02', 'B01')
# Order refers to 2, 3, 4, 5, 6, 7, 8, 8A, 11, 12 as described in the dataset paper
sentinel_all_bands: tuple[str, ...] = ('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B08', 'B09', 'B10')
sentinel_all_bands: tuple[str, ...] = (
'B01',
'B02',
'B03',
'B04',
'B05',
'B06',
'B07',
'B08',
'B08',
'B09',
'B10',
)

# Note: the original dataset contains 18 classes, but the dataset paper suggests
# grouping all classes >13 into "other" class, due to underrepresentation
Expand Down Expand Up @@ -229,7 +247,7 @@ def __init__(
download: bool = False,
checksum: bool = False,
use_sentinel: bool = False,
sentinel_bands: Sequence[str] = sentinel_all_bands
sentinel_bands: Sequence[str] = sentinel_all_bands,
) -> None:
"""Initialize a new FLAIR2 dataset instance.
Expand Down Expand Up @@ -272,7 +290,11 @@ def get_num_bands(self, include_sentinel_bands: bool = False) -> int:
Returns:
int: number of bands in the initialized dataset (might vary from all_bands)
"""
return len(self.aerial_bands) if not include_sentinel_bands else len(self.aerial_bands) + len(self.sentinel_bands)
return (
len(self.aerial_bands)
if not include_sentinel_bands
else len(self.aerial_bands) + len(self.sentinel_bands)
)

def __getitem__(self, index: int) -> dict[str, Tensor]:
"""Return an index within the dataset.
Expand Down Expand Up @@ -576,7 +598,9 @@ def normalize_plot(tensor: Tensor) -> Tensor:
"""Normalize the plot."""
return (tensor - tensor.min()) / (tensor.max() - tensor.min())

rgb_indices = [self.aerial_all_bands.index(band) for band in self.aerial_rgb_bands]
rgb_indices = [
self.aerial_all_bands.index(band) for band in self.aerial_rgb_bands
]
# Check if RGB bands are present in self.bands
if not all([band in self.aerial_bands for band in self.aerial_rgb_bands]):
raise RGBBandsMissingError()
Expand All @@ -588,7 +612,11 @@ def normalize_plot(tensor: Tensor) -> Tensor:
if 'B05' in self.aerial_bands:
elevation = sample['image'][self.aerial_bands.index('B05')]
if 'B04' in self.aerial_bands:
nir_r_g_indices = [self.aerial_bands.index('B04'), rgb_indices[0], rgb_indices[1]]
nir_r_g_indices = [
self.aerial_bands.index('B04'),
rgb_indices[0],
rgb_indices[1],
]
nir_r_g = normalize_plot(sample['image'][nir_r_g_indices].permute(1, 2, 0))

# Sentinel is a time-series, i.e. use [0]->T=0
Expand Down

0 comments on commit f49b84f

Please sign in to comment.