Skip to content

Commit

Permalink
So2Sat hacks
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Sep 22, 2023
1 parent fdeeae9 commit 6c80d70
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
12 changes: 10 additions & 2 deletions torchgeo/datamodules/so2sat.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,11 @@ def __init__(
.. versionadded:: 0.5
The *val_split_pct* parameter, and the 'rgb' argument to *band_set*.
"""
version = kwargs.get("version", "2")
if "version" in kwargs:
kwargs["version"] = str(kwargs["version"])
version = kwargs["version"]
else:
version = "2"
kwargs["bands"] = So2Sat.BAND_SETS[band_set]
self.val_split_pct = val_split_pct

Expand All @@ -209,7 +213,11 @@ def setup(self, stage: str) -> None:
Args:
stage: Either 'fit', 'validate', 'test', or 'predict'.
"""
if self.kwargs.get("version", "2") == "2":
if "version" in self.kwargs:
version = self.kwargs["version"]
else:
version = "2"
if version == "2":
if stage in ["fit"]:
self.train_dataset = So2Sat(split="train", **self.kwargs)
if stage in ["fit", "validate"]:
Expand Down
2 changes: 0 additions & 2 deletions torchgeo/datasets/so2sat.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,6 @@ def __init__(
raise ImportError(
"h5py is not installed and is required to use this dataset"
)

version = str(version)
assert version in self.versions
assert split in self.filenames_by_version[version]

Expand Down

0 comments on commit 6c80d70

Please sign in to comment.