diff --git a/.github/labeler.yml b/.github/labeler.yml new file mode 100644 index 00000000000..193edeef897 --- /dev/null +++ b/.github/labeler.yml @@ -0,0 +1,21 @@ +# TorchGeo modules +datamodules: +- torchgeo/datamodules/** +datasets: +- torchgeo/datasets/** +losses: +- torchgeo/losses/** +models: +- torchgeo/models/** +samplers: +- torchgeo/samplers/** +trainers: +- torchgeo/trainers/** +transforms: +- torchgeo/transforms/** + +# Other +documentation: +- docs/** +testing: +- tests/** diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml new file mode 100644 index 00000000000..4e3daee0c33 --- /dev/null +++ b/.github/workflows/labeler.yml @@ -0,0 +1,12 @@ +name: "labeler" +on: +- pull_request_target + +jobs: + label: + runs-on: ubuntu-latest + steps: + - uses: actions/labeler@v3 + with: + repo-token: "${{ secrets.GITHUB_TOKEN }}" + sync-labels: true diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index fc981f1073b..a4e23d43d9a 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -7,8 +7,8 @@ on: branches: - release** jobs: - notebooks: - name: notebooks + datasets: + name: datasets runs-on: ubuntu-latest steps: - name: Clone repo @@ -19,12 +19,10 @@ jobs: python-version: 3.9 - name: Install pip dependencies run: | - pip install .[datasets,tests] - pip install -r docs/requirements.txt - - name: Run notebook checks - env: - MLHUB_API_KEY: ${{ secrets.MLHUB_API_KEY }} - run: pytest --nbmake docs/tutorials + pip install cython numpy # needed for pycocotools + pip install .[tests] + - name: Run pytest checks + run: pytest --cov=torchgeo --cov-report=xml integration: name: integration runs-on: ubuntu-latest @@ -39,3 +37,21 @@ jobs: run: pip install .[datasets,tests] - name: Run integration checks run: pytest -m slow + notebooks: + name: notebooks + runs-on: ubuntu-latest + steps: + - name: Clone repo + uses: actions/checkout@v2 + - name: Set up python + uses: actions/setup-python@v2 + with: + python-version: 3.9 + - name: Install pip dependencies + run: | + pip install .[datasets,tests] + pip install -r docs/requirements.txt + - name: Run notebook checks + env: + MLHUB_API_KEY: ${{ secrets.MLHUB_API_KEY }} + run: pytest --nbmake docs/tutorials diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 4154f48cd51..926a1dee9b3 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -25,22 +25,6 @@ jobs: pip install .[datasets,tests] - name: Run mypy checks run: mypy . - datasets: - name: datasets - runs-on: ubuntu-latest - steps: - - name: Clone repo - uses: actions/checkout@v2 - - name: Set up python - uses: actions/setup-python@v2 - with: - python-version: 3.9 - - name: Install pip dependencies - run: | - pip install cython numpy # needed for pycocotools - pip install .[tests] - - name: Run pytest checks - run: pytest --cov=torchgeo --cov-report=xml pytest: name: pytest runs-on: ${{ matrix.os }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 81545d56034..0a439c88f5c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,12 +1,12 @@ repos: - repo: https://github.com/pycqa/isort - rev: 5.8.0 + rev: 5.10.0 hooks: - id: isort additional_dependencies: ["colorama>=0.4.3"] - repo: https://github.com/psf/black - rev: 21.4b0 + rev: 21.12b0 hooks: - id: black args: [--skip-magic-trailing-comma] @@ -24,7 +24,7 @@ repos: additional_dependencies: ["toml"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.920 + rev: v0.930 hooks: - id: mypy args: [--strict, --ignore-missing-imports, --show-error-codes] diff --git a/conf/defaults.yaml b/conf/defaults.yaml index ca50a41571e..1ccc0c9168f 100644 --- a/conf/defaults.yaml +++ b/conf/defaults.yaml @@ -1,7 +1,7 @@ config_file: null # This lets the user pass a config filename to load other arguments from program: # These are the arguments that define how the train.py script works - seed: 1337 + seed: 0 output_dir: output data_dir: data log_dir: logs @@ -16,16 +16,17 @@ experiment: # These are arugments specific to the experiment we are running root_dir: ${program.data_dir} seed: ${program.seed} batch_size: 32 - num_workers: 4 + num_workers: 0 # The values here are taken from the defaults here https://pytorch-lightning.readthedocs.io/en/1.3.8/common/trainer.html#init # this probably should be made into a schema, e.g. as shown https://omegaconf.readthedocs.io/en/2.0_branch/structured_config.html#merging-with-other-configs trainer: # These are the parameters passed to the pytorch lightning Trainer object logger: True - checkpoint_callback: True callbacks: null default_root_dir: null + detect_anomaly: False + enable_checkpointing: True gradient_clip_val: 0.0 gradient_clip_algorithm: 'norm' process_position: 0 @@ -43,7 +44,7 @@ trainer: # These are the parameters passed to the pytorch lightning Trainer obje accumulate_grad_batches: 1 max_epochs: null min_epochs: null - max_steps: null + max_steps: -1 min_steps: null max_time: null limit_train_batches: 1.0 @@ -51,8 +52,7 @@ trainer: # These are the parameters passed to the pytorch lightning Trainer obje limit_test_batches: 1.0 limit_predict_batches: 1.0 val_check_interval: 1.0 - flush_logs_every_n_steps: 100 - log_every_n_steps: 50 + log_every_n_steps: 1 accelerator: null sync_batchnorm: False precision: 32 @@ -66,9 +66,7 @@ trainer: # These are the parameters passed to the pytorch lightning Trainer obje reload_dataloaders_every_epoch: False auto_lr_find: False replace_sampler_ddp: True - terminate_on_nan: False auto_scale_batch_size: False - prepare_data_per_node: True plugins: null amp_backend: 'native' move_metrics_to_cpu: False diff --git a/conf/task_defaults/bigearthnet.yaml b/conf/task_defaults/bigearthnet_all.yaml similarity index 94% rename from conf/task_defaults/bigearthnet.yaml rename to conf/task_defaults/bigearthnet_all.yaml index c5c352861ba..ba55f85e709 100644 --- a/conf/task_defaults/bigearthnet.yaml +++ b/conf/task_defaults/bigearthnet_all.yaml @@ -12,5 +12,5 @@ experiment: root_dir: "tests/data/bigearthnet" bands: "all" num_classes: ${experiment.module.num_classes} - batch_size: 128 + batch_size: 1 num_workers: 0 diff --git a/conf/task_defaults/bigearthnet_s1.yaml b/conf/task_defaults/bigearthnet_s1.yaml new file mode 100644 index 00000000000..73e727a04ea --- /dev/null +++ b/conf/task_defaults/bigearthnet_s1.yaml @@ -0,0 +1,16 @@ +experiment: + task: "bigearthnet" + module: + loss: "bce" + classification_model: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weights: "random" + in_channels: 2 + num_classes: 19 + datamodule: + root_dir: "tests/data/bigearthnet" + bands: "s1" + num_classes: ${experiment.module.num_classes} + batch_size: 1 + num_workers: 0 diff --git a/conf/task_defaults/bigearthnet_s2.yaml b/conf/task_defaults/bigearthnet_s2.yaml new file mode 100644 index 00000000000..9b9983a461d --- /dev/null +++ b/conf/task_defaults/bigearthnet_s2.yaml @@ -0,0 +1,16 @@ +experiment: + task: "bigearthnet" + module: + loss: "bce" + classification_model: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weights: "random" + in_channels: 12 + num_classes: 19 + datamodule: + root_dir: "tests/data/bigearthnet" + bands: "s2" + num_classes: ${experiment.module.num_classes} + batch_size: 1 + num_workers: 0 diff --git a/conf/task_defaults/byol.yaml b/conf/task_defaults/byol.yaml index 90eefbc96f6..d79b0def86c 100644 --- a/conf/task_defaults/byol.yaml +++ b/conf/task_defaults/byol.yaml @@ -16,5 +16,5 @@ experiment: - "de-test" test_splits: - "de-test" - batch_size: 64 + batch_size: 1 num_workers: 0 diff --git a/conf/task_defaults/chesapeake_cvpr_5.yaml b/conf/task_defaults/chesapeake_cvpr_5.yaml new file mode 100644 index 00000000000..63b0b469f42 --- /dev/null +++ b/conf/task_defaults/chesapeake_cvpr_5.yaml @@ -0,0 +1,29 @@ +experiment: + task: "chesapeake_cvpr" + module: + loss: "ce" + segmentation_model: "unet" + encoder_name: "resnet50" + encoder_weights: null + encoder_output_stride: 16 + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + in_channels: 4 + num_classes: 5 + num_filters: 1 + ignore_zeros: False + imagenet_pretraining: False + datamodule: + root_dir: "tests/data/chesapeake/cvpr" + train_splits: + - "de-test" + val_splits: + - "de-test" + test_splits: + - "de-test" + patches_per_tile: 2 + patch_size: 64 + batch_size: 2 + num_workers: 0 + class_set: ${experiment.module.num_classes} + use_prior_labels: False diff --git a/conf/task_defaults/chesapeake_cvpr.yaml b/conf/task_defaults/chesapeake_cvpr_7.yaml similarity index 79% rename from conf/task_defaults/chesapeake_cvpr.yaml rename to conf/task_defaults/chesapeake_cvpr_7.yaml index b210d592636..b1cd0bde844 100644 --- a/conf/task_defaults/chesapeake_cvpr.yaml +++ b/conf/task_defaults/chesapeake_cvpr_7.yaml @@ -10,8 +10,9 @@ experiment: learning_rate_schedule_patience: 6 in_channels: 4 num_classes: 7 - num_filters: 256 + num_filters: 1 ignore_zeros: False + imagenet_pretraining: False datamodule: root_dir: "tests/data/chesapeake/cvpr" train_splits: @@ -20,8 +21,9 @@ experiment: - "de-test" test_splits: - "de-test" - patches_per_tile: 200 - patch_size: 256 - batch_size: 64 + patches_per_tile: 2 + patch_size: 64 + batch_size: 2 num_workers: 0 class_set: ${experiment.module.num_classes} + use_prior_labels: False diff --git a/conf/task_defaults/chesapeake_cvpr_prior.yaml b/conf/task_defaults/chesapeake_cvpr_prior.yaml new file mode 100644 index 00000000000..ab7398da3b9 --- /dev/null +++ b/conf/task_defaults/chesapeake_cvpr_prior.yaml @@ -0,0 +1,29 @@ +experiment: + task: "chesapeake_cvpr" + module: + loss: "ce" + segmentation_model: "unet" + encoder_name: "resnet50" + encoder_weights: null + encoder_output_stride: 16 + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + in_channels: 4 + num_classes: 5 + num_filters: 1 + ignore_zeros: False + imagenet_pretraining: False + datamodule: + root_dir: "tests/data/chesapeake/cvpr" + train_splits: + - "de-test" + val_splits: + - "de-test" + test_splits: + - "de-test" + patches_per_tile: 2 + patch_size: 64 + batch_size: 2 + num_workers: 0 + class_set: ${experiment.module.num_classes} + use_prior_labels: True diff --git a/conf/task_defaults/cowc_counting.yaml b/conf/task_defaults/cowc_counting.yaml index a0d43a23462..4a4ab2f9abb 100644 --- a/conf/task_defaults/cowc_counting.yaml +++ b/conf/task_defaults/cowc_counting.yaml @@ -4,7 +4,9 @@ experiment: model: resnet18 learning_rate: 1e-3 learning_rate_schedule_patience: 2 + pretrained: False datamodule: root_dir: "tests/data/cowc_counting" - batch_size: 32 + seed: 0 + batch_size: 1 num_workers: 0 diff --git a/conf/task_defaults/cyclone.yaml b/conf/task_defaults/cyclone.yaml index 3e2fe918094..ba3d039f9d7 100644 --- a/conf/task_defaults/cyclone.yaml +++ b/conf/task_defaults/cyclone.yaml @@ -7,5 +7,6 @@ experiment: pretrained: False datamodule: root_dir: "tests/data/cyclone" - batch_size: 32 + seed: 0 + batch_size: 1 num_workers: 0 diff --git a/conf/task_defaults/etci2021.yaml b/conf/task_defaults/etci2021.yaml index 132887630e2..880ac6232e6 100644 --- a/conf/task_defaults/etci2021.yaml +++ b/conf/task_defaults/etci2021.yaml @@ -12,5 +12,5 @@ experiment: ignore_zeros: True datamodule: root_dir: "tests/data/etci2021" - batch_size: 32 + batch_size: 1 num_workers: 0 diff --git a/conf/task_defaults/eurosat.yaml b/conf/task_defaults/eurosat.yaml index dd40b85f6c9..9161c2c0cae 100644 --- a/conf/task_defaults/eurosat.yaml +++ b/conf/task_defaults/eurosat.yaml @@ -10,5 +10,5 @@ experiment: num_classes: 10 datamodule: root_dir: "tests/data/eurosat" - batch_size: 128 + batch_size: 1 num_workers: 0 diff --git a/conf/task_defaults/landcoverai.yaml b/conf/task_defaults/landcoverai.yaml index 72e2f024999..4e28b018935 100644 --- a/conf/task_defaults/landcoverai.yaml +++ b/conf/task_defaults/landcoverai.yaml @@ -10,9 +10,9 @@ experiment: verbose: false in_channels: 3 num_classes: 6 - num_filters: 256 + num_filters: 1 ignore_zeros: False datamodule: root_dir: "tests/data/landcoverai" - batch_size: 32 + batch_size: 1 num_workers: 0 diff --git a/conf/task_defaults/naipchesapeake.yaml b/conf/task_defaults/naipchesapeake.yaml index e3546ad6ff1..83814f3316a 100644 --- a/conf/task_defaults/naipchesapeake.yaml +++ b/conf/task_defaults/naipchesapeake.yaml @@ -10,11 +10,11 @@ experiment: learning_rate_schedule_patience: 2 in_channels: 4 num_classes: 13 - num_filters: 64 + num_filters: 1 ignore_zeros: False datamodule: naip_root_dir: "tests/data/naip" chesapeake_root_dir: "tests/data/chesapeake/BAYWIDE" - batch_size: 32 + batch_size: 2 num_workers: 0 patch_size: 32 diff --git a/conf/task_defaults/oscd.yaml b/conf/task_defaults/oscd_all.yaml similarity index 74% rename from conf/task_defaults/oscd.yaml rename to conf/task_defaults/oscd_all.yaml index 5ae3fdccb78..9cd70b45daa 100644 --- a/conf/task_defaults/oscd.yaml +++ b/conf/task_defaults/oscd_all.yaml @@ -4,18 +4,18 @@ experiment: loss: "jaccard" segmentation_model: "unet" encoder_name: "resnet18" - encoder_weights: null + encoder_weights: null learning_rate: 1e-3 learning_rate_schedule_patience: 6 verbose: false in_channels: 26 num_classes: 2 - num_filters: 256 + num_filters: 1 ignore_zeros: True datamodule: root_dir: "tests/data/oscd" - batch_size: 32 + batch_size: 1 num_workers: 0 - val_split_pct: 0.1 + val_split_pct: 0.5 bands: "all" - num_patches_per_tile: 128 + num_patches_per_tile: 1 diff --git a/conf/task_defaults/oscd_rgb.yaml b/conf/task_defaults/oscd_rgb.yaml new file mode 100644 index 00000000000..bf24bd4b02a --- /dev/null +++ b/conf/task_defaults/oscd_rgb.yaml @@ -0,0 +1,21 @@ +experiment: + task: "oscd" + module: + loss: "jaccard" + segmentation_model: "unet" + encoder_name: "resnet18" + encoder_weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + verbose: false + in_channels: 6 + num_classes: 2 + num_filters: 1 + ignore_zeros: True + datamodule: + root_dir: "tests/data/oscd" + batch_size: 1 + num_workers: 0 + val_split_pct: 0.5 + bands: "rgb" + num_patches_per_tile: 1 diff --git a/conf/task_defaults/resisc45.yaml b/conf/task_defaults/resisc45.yaml index a657cc5d02a..e95efe9af89 100644 --- a/conf/task_defaults/resisc45.yaml +++ b/conf/task_defaults/resisc45.yaml @@ -10,5 +10,5 @@ experiment: num_classes: 45 datamodule: root_dir: "tests/data/resisc45" - batch_size: 128 + batch_size: 1 num_workers: 0 diff --git a/conf/task_defaults/sen12ms.yaml b/conf/task_defaults/sen12ms_all.yaml similarity index 87% rename from conf/task_defaults/sen12ms.yaml rename to conf/task_defaults/sen12ms_all.yaml index e4a946a23f8..1a0f73fddd4 100644 --- a/conf/task_defaults/sen12ms.yaml +++ b/conf/task_defaults/sen12ms_all.yaml @@ -13,5 +13,7 @@ experiment: ignore_zeros: False datamodule: root_dir: "tests/data/sen12ms" - batch_size: 32 + band_set: "all" + batch_size: 1 num_workers: 0 + seed: 0 diff --git a/conf/task_defaults/sen12ms_s1.yaml b/conf/task_defaults/sen12ms_s1.yaml new file mode 100644 index 00000000000..a2fdbb17031 --- /dev/null +++ b/conf/task_defaults/sen12ms_s1.yaml @@ -0,0 +1,20 @@ +experiment: + task: "sen12ms" + module: + loss: "focal" + segmentation_model: "fcn" + num_filters: 1 + encoder_name: "resnet18" + encoder_weights: null + encoder_output_stride: 16 + learning_rate: 1e-3 + learning_rate_schedule_patience: 2 + in_channels: 2 + num_classes: 11 + ignore_zeros: False + datamodule: + root_dir: "tests/data/sen12ms" + band_set: "s1" + batch_size: 1 + num_workers: 0 + seed: 0 diff --git a/conf/task_defaults/sen12ms_s2_all.yaml b/conf/task_defaults/sen12ms_s2_all.yaml new file mode 100644 index 00000000000..eb081ef722f --- /dev/null +++ b/conf/task_defaults/sen12ms_s2_all.yaml @@ -0,0 +1,19 @@ +experiment: + task: "sen12ms" + module: + loss: "ce" + segmentation_model: "unet" + encoder_name: "resnet18" + encoder_weights: null + encoder_output_stride: 16 + learning_rate: 1e-3 + learning_rate_schedule_patience: 2 + in_channels: 13 + num_classes: 11 + ignore_zeros: False + datamodule: + root_dir: "tests/data/sen12ms" + band_set: "s2-all" + batch_size: 1 + num_workers: 0 + seed: 0 diff --git a/conf/task_defaults/sen12ms_s2_reduced.yaml b/conf/task_defaults/sen12ms_s2_reduced.yaml new file mode 100644 index 00000000000..e44c20a3dbd --- /dev/null +++ b/conf/task_defaults/sen12ms_s2_reduced.yaml @@ -0,0 +1,19 @@ +experiment: + task: "sen12ms" + module: + loss: "ce" + segmentation_model: "unet" + encoder_name: "resnet18" + encoder_weights: null + encoder_output_stride: 16 + learning_rate: 1e-3 + learning_rate_schedule_patience: 2 + in_channels: 6 + num_classes: 11 + ignore_zeros: False + datamodule: + root_dir: "tests/data/sen12ms" + band_set: "s2-reduced" + batch_size: 1 + num_workers: 0 + seed: 0 diff --git a/conf/task_defaults/so2sat.yaml b/conf/task_defaults/so2sat_supervised.yaml similarity index 81% rename from conf/task_defaults/so2sat.yaml rename to conf/task_defaults/so2sat_supervised.yaml index be6a4e0f5c9..0f215c24c84 100644 --- a/conf/task_defaults/so2sat.yaml +++ b/conf/task_defaults/so2sat_supervised.yaml @@ -1,7 +1,7 @@ experiment: task: "so2sat" module: - loss: "ce" + loss: "focal" classification_model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 @@ -10,6 +10,7 @@ experiment: num_classes: 17 datamodule: root_dir: "tests/data/so2sat" - batch_size: 128 + batch_size: 1 num_workers: 0 bands: "rgb" + unsupervised_mode: False diff --git a/conf/task_defaults/so2sat_unsupervised.yaml b/conf/task_defaults/so2sat_unsupervised.yaml new file mode 100644 index 00000000000..ec51d18c4a6 --- /dev/null +++ b/conf/task_defaults/so2sat_unsupervised.yaml @@ -0,0 +1,16 @@ +experiment: + task: "so2sat" + module: + loss: "jaccard" + classification_model: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weights: "random" + in_channels: 3 + num_classes: 17 + datamodule: + root_dir: "tests/data/so2sat" + batch_size: 1 + num_workers: 0 + bands: "rgb" + unsupervised_mode: True diff --git a/conf/task_defaults/ucmerced.yaml b/conf/task_defaults/ucmerced.yaml index 2742fa329c8..31f7dba2960 100644 --- a/conf/task_defaults/ucmerced.yaml +++ b/conf/task_defaults/ucmerced.yaml @@ -10,5 +10,5 @@ experiment: num_classes: 21 datamodule: root_dir: "tests/data/ucmerced" - batch_size: 128 + batch_size: 1 num_workers: 0 diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst new file mode 100644 index 00000000000..ad47f9cd648 --- /dev/null +++ b/docs/api/datamodules.rst @@ -0,0 +1,105 @@ +torchgeo.datamodules +==================== + +.. module:: torchgeo.datamodules + +Geospatial DataModules +---------------------- + +Chesapeake Bay High-Resolution Land Cover Project +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: ChesapeakeCVPRDataModule + +National Agriculture Imagery Program (NAIP) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: NAIPChesapeakeDataModule + +Non-geospatial DataModules +-------------------------- + +BigEarthNet +^^^^^^^^^^^ + +.. autoclass:: BigEarthNetDataModule + +Cars Overhead With Context (COWC) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: COWCCountingDataModule + +ETCI2021 Flood Detection +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: ETCI2021DataModule + +EuroSAT +^^^^^^^ + +.. autoclass:: EuroSATDataModule + +FAIR1M (Fine-grAined object recognItion in high-Resolution imagery) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: FAIR1MDataModule + +LandCover.ai (Land Cover from Aerial Imagery) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: LandCoverAIDataModule + +LoveDA (Land-cOVEr Domain Adaptive semantic segmentation) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: LoveDADataModule + +NASA Marine Debris +^^^^^^^^^^^^^^^^^^ + +.. autoclass:: NASAMarineDebrisDataModule + +OSCD (Onera Satellite Change Detection) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: OSCDDataModule + +Potsdam +^^^^^^^ + +.. autoclass:: Potsdam2DDataModule + +RESISC45 (Remote Sensing Image Scene Classification) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: RESISC45DataModule + +SEN12MS +^^^^^^^ + +.. autoclass:: SEN12MSDataModule + +So2Sat +^^^^^^ + +.. autoclass:: So2SatDataModule + +Tropical Cyclone Wind Estimation Competition +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: CycloneDataModule + +UC Merced +^^^^^^^^^ + +.. autoclass:: UCMercedDataModule + +Vaihingen +^^^^^^^^^ + +.. autoclass:: Vaihingen2DDataModule + +xView2 +^^^^^^ + +.. autoclass:: XView2DataModule diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 4b30c762a76..f3df706a8b4 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -31,7 +31,6 @@ Chesapeake Bay High-Resolution Land Cover Project .. autoclass:: ChesapeakeVA .. autoclass:: ChesapeakeWV .. autoclass:: ChesapeakeCVPR -.. autoclass:: ChesapeakeCVPRDataModule Cropland Data Layer (CDL) ^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -57,7 +56,6 @@ National Agriculture Imagery Program (NAIP) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: NAIP -.. autoclass:: NAIPChesapeakeDataModule Sentinel ^^^^^^^^ @@ -86,7 +84,6 @@ BigEarthNet ^^^^^^^^^^^ .. autoclass:: BigEarthNet -.. autoclass:: BigEarthNetDataModule Cars Overhead With Context (COWC) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -94,7 +91,6 @@ Cars Overhead With Context (COWC) .. autoclass:: COWC .. autoclass:: COWCCounting .. autoclass:: COWCDetection -.. autoclass:: COWCCountingDataModule CV4A Kenya Crop Type Competition ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -105,19 +101,16 @@ ETCI2021 Flood Detection ^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: ETCI2021 -.. autoclass:: ETCI2021DataModule EuroSAT -^^^^^^^^^^^^^^^^^^^^^^^^ +^^^^^^^ .. autoclass:: EuroSAT -.. autoclass:: EuroSATDataModule FAIR1M (Fine-grAined object recognItion in high-Resolution imagery) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: FAIR1M -.. autoclass:: FAIR1MDataModule GID-15 (Gaofen Image Dataset) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -133,7 +126,6 @@ LandCover.ai (Land Cover from Aerial Imagery) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: LandCoverAI -.. autoclass:: LandCoverAIDataModule LEVIR-CD+ (LEVIR Change Detection +) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -144,19 +136,16 @@ LoveDA (Land-cOVEr Domain Adaptive semantic segmentation) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: LoveDA -.. autoclass:: LoveDADataModule NASA Marine Debris ^^^^^^^^^^^^^^^^^^ .. autoclass:: NASAMarineDebris -.. autoclass:: NASAMarineDebrisDataModule OSCD (Onera Satellite Change Detection) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: OSCD -.. autoclass:: OSCDDataModule PatternNet ^^^^^^^^^^ @@ -167,13 +156,11 @@ Potsdam ^^^^^^^ .. autoclass:: Potsdam2D -.. autoclass:: Potsdam2DDataModule RESISC45 (Remote Sensing Image Scene Classification) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: RESISC45 -.. autoclass:: RESISC45DataModule Seasonal Contrast ^^^^^^^^^^^^^^^^^ @@ -184,13 +171,11 @@ SEN12MS ^^^^^^^ .. autoclass:: SEN12MS -.. autoclass:: SEN12MSDataModule So2Sat ^^^^^^ .. autoclass:: So2Sat -.. autoclass:: So2SatDataModule SpaceNet ^^^^^^^^ @@ -206,30 +191,26 @@ Tropical Cyclone Wind Estimation Competition ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: TropicalCycloneWindEstimation -.. autoclass:: CycloneDataModule + +UC Merced +^^^^^^^^^ + +.. autoclass:: UCMerced Vaihingen ^^^^^^^^^ .. autoclass:: Vaihingen2D -.. autoclass:: Vaihingen2DDataModule NWPU VHR-10 ^^^^^^^^^^^ .. autoclass:: VHR10 -UC Merced -^^^^^^^^^ - -.. autoclass:: UCMerced -.. autoclass:: UCMercedDataModule - xView2 ^^^^^^ .. autoclass:: XView2 -.. autoclass:: XView2DataModule ZueriCrop ^^^^^^^^^ diff --git a/docs/api/losses.rst b/docs/api/losses.rst new file mode 100644 index 00000000000..3ce97d2e169 --- /dev/null +++ b/docs/api/losses.rst @@ -0,0 +1,4 @@ +torchgeo.losses +================= + +.. automodule:: torchgeo.losses diff --git a/docs/conf.py b/docs/conf.py index d6c87624ed7..55d89475124 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -50,8 +50,8 @@ exclude_patterns = ["_build"] # Sphinx 3.0+ required for: -# autodoc_typehints = "description" -needs_sphinx = "3.0" +# autodoc_typehints_description_target = "documented" +needs_sphinx = "4.0" nitpicky = True nitpick_ignore = [ @@ -98,6 +98,7 @@ } autodoc_member_order = "bysource" autodoc_typehints = "description" +autodoc_typehints_description_target = "documented" # sphinx.ext.intersphinx intersphinx_mapping = { diff --git a/docs/index.rst b/docs/index.rst index 296c7a35ea8..6228db3805b 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -15,7 +15,9 @@ torchgeo :maxdepth: 2 :caption: Package Reference + api/datamodules api/datasets + api/losses api/models api/samplers api/trainers @@ -26,6 +28,7 @@ torchgeo :caption: Tutorials tutorials/getting_started + tutorials/custom_raster_dataset tutorials/transforms tutorials/indices tutorials/trainers diff --git a/docs/requirements.txt b/docs/requirements.txt index c4a5e359b39..0e7387bebe4 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,5 +4,5 @@ ipywidgets>=7 nbsphinx>=0.8.5 # release versions missing files, must install from master -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme -# sphinx 3+ required for autodoc_typehints = description -sphinx>=3 +# sphinx 4+ required for autodoc_typehints_description_target = documented +sphinx>=4 diff --git a/docs/tutorials/custom_raster_dataset.ipynb b/docs/tutorials/custom_raster_dataset.ipynb new file mode 100644 index 00000000000..a90792f1303 --- /dev/null +++ b/docs/tutorials/custom_raster_dataset.ipynb @@ -0,0 +1,188 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Copyright (c) Microsoft Corporation. All rights reserved.\n", + "\n", + "Licensed under the MIT License." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Custom Raster Datasets\n", + "\n", + "In this tutorial, we demonstrate how to create a custom `RasterDataset` for our own data. We will use the [xView3](https://iuu.xview.us/) tiny dataset as an example." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install torchgeo" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "from typing import Callable, Dict, Optional\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import torch\n", + "from rasterio.crs import CRS\n", + "from torch import Tensor\n", + "from torch.utils.data import DataLoader\n", + "from torchgeo.datasets import RasterDataset, stack_samples\n", + "from torchgeo.samplers import RandomGeoSampler\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Custom RasterDataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Unzipping the sample xView3 data from the tests folder" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "from torchgeo.datasets.utils import extract_archive\n", + "\n", + "data_root = Path('../../tests/data/xview3/')\n", + "extract_archive(str(data_root / 'sample_data.tar.gz'))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we have the xView3 tiny dataset downloaded and unzipped in our local directory. Note that all the test GeoTIFFs are comprised entirely of zeros. Any plotted image will appear to be entirely uniform.\n", + "\n", + " xview3\n", + " ├── 05bc615a9b0aaaaaa\n", + " │ ├── bathymetry.tif\n", + " │ ├── owiMask.tif\n", + " │ ├── owiWindDirection.tif\n", + " │ ├── owiWindQuality.tif\n", + " │ ├── owiWindSpeed.tif\n", + " │ ├── VH_dB.tif\n", + " │ └── VV_dB.tif\n", + "\n", + "We would like to create a custom Dataset class based off of RasterDataset for this xView3 data. This will let us use `torchgeo` features such as: random sampling, merging other layers, fusing multiple datasets with `UnionDataset` and `IntersectionDataset`, and more. To do this, we can simply subclass `RasterDataset` and define a `filename_glob` property to select which files in a root directory will be included in the dataset. For example:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "class XView3Polarizations(RasterDataset):\n", + " '''\n", + " Load xView3 polarization data that ends in *_dB.tif\n", + " '''\n", + "\n", + " filename_glob = \"*_dB.tif\"" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 1, 102, 102])\n", + "torch.Size([1, 1, 102, 102])\n", + "torch.Size([1, 1, 102, 102])\n", + "torch.Size([1, 1, 102, 102])\n", + "torch.Size([1, 1, 102, 102])\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQgAAAD7CAYAAACWhwr8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAL70lEQVR4nO3cYaydBX3H8e9vreiAGNtxIZXCikmjMhOHuXEgy2KobOiM5Q0JJizNRtI3bqIxMWW8IHvnC2P0xWbSgNpMAiFIbEOcSKrG7A3jImQDCpaJK5XaXmemiy8cxP9enGfjrLn/tT3nnp5z4/eTNM95nnOec/5cer95ztPznFQVkrSW35r3AJIWl4GQ1DIQkloGQlLLQEhqGQhJrZkFIslNSV5I8mKSfbN6HUmzk1l8DiLJJuAHwI3AceAJ4KNV9dy6v5ikmdk8o+d9L/BiVf0QIMkDwG5gzUBccskltWPHjhmNIgngySef/GlVLZ3LPrMKxOXAy2Prx4E/GH9Akr3AXoArr7ySlZWVGY0iCSDJv53rPrM6B5E1tv2f9zJVtb+qlqtqeWnpnKIm6TyZVSCOA1eMrW8HXpnRa0makVkF4glgZ5KrklwA3AocmtFrSZqRmZyDqKrXkvwl8CiwCfhSVT07i9eSNDuzOklJVX0D+Masnl/S7PlJSkktAyGpZSAktQyEpJaBkNQyEJJaBkJSy0BIahkISS0DIallICS1DISkloGQ1DIQkloGQlLLQEhqGQhJLQMhqWUgJLUMhKSWgZDUMhCSWgZCUstASGoZCEktAyGpZSAktQyEpJaBkNQyEJJaBkJSy0BIahkISS0DIak1cSCSXJHkO0mOJHk2yR3D9q1JHktydFhuWb9xJZ1P0xxBvAZ8qqreCVwLfCzJ1cA+4HBV7QQOD+uSNqCJA1FVJ6rq+8Pt/wSOAJcDu4EDw8MOADdPOaOkOVmXcxBJdgDXAI8Dl1XVCRhFBLi02WdvkpUkK6urq+sxhqR1NnUgklwMfA34RFX94mz3q6r9VbVcVctLS0vTjiFpBqYKRJI3MIrDfVX18LD5ZJJtw/3bgFPTjShpXqb5V4wA9wJHqupzY3cdAvYMt/cABycfT9I8bZ5i3+uBPwP+JcnTw7a/Bj4DPJjkduAYcMtUE0qam4kDUVX/CKS5e9ekzytpcfhJSkktAyGpZSAktQyEpJaBkNQyEJJaBkJSy0BIahkISS0DIallICS1DISkloGQ1DIQkloGQlLLQEhqGQhJLQMhqWUgJLUMhKSWgZDUMhCSWgZCUstASGoZCEktAyGpZSAktQyEpJaBkNQyEJJaBkJSy0BIahkISa2pA5FkU5KnkjwyrG9N8liSo8Nyy/RjSpqH9TiCuAM4Mra+DzhcVTuBw8O6pA1oqkAk2Q78KXDP2ObdwIHh9gHg5mleQ9L8THsE8Xng08Cvx7ZdVlUnAIblpVO+hqQ5mTgQST4MnKqqJyfcf2+SlSQrq6urk44haYamOYK4HvhIkh8BDwA3JPkqcDLJNoBheWqtnatqf1UtV9Xy0tLSFGNImpWJA1FVd1bV9qraAdwKfLuqbgMOAXuGh+0BDk49paS5mMXnID4D3JjkKHDjsC5pA9q8Hk9SVd8Fvjvc/ndg13o8r6T58pOUkloGQlLLQEhqGQhJLQMhqWUgJLUMhKSWgZDUMhCSWgZCUstASGoZCEktAyGpZSAktQyEpJaBkNQyEJJaBkJSy0BIahkISS0DIallICS1DISkloGQ1DIQkloGQlLLQEhqGQhJLQMhqWUgJLUMhKSWgZDUMhCSWgZCUmuqQCR5S5KHkjyf5EiS65JsTfJYkqPDcst6DSvp/Jr2COILwDer6h3Au4EjwD7gcFXtBA4P65I2oIkDkeTNwB8B9wJU1X9V1X8Au4EDw8MOADdPN6KkeZnmCOJtwCrw5SRPJbknyUXAZVV1AmBYXrrWzkn2JllJsrK6ujrFGJJmZZpAbAbeA3yxqq4Bfsk5vJ2oqv1VtVxVy0tLS1OMIWlWpgnEceB4VT0+rD/EKBgnk2wDGJanphtR0rxMHIiq+gnwcpK3D5t2Ac8Bh4A9w7Y9wMGpJpQ0N5un3P+vgPuSXAD8EPhzRtF5MMntwDHglilfQ9KcTBWIqnoaWF7jrl3TPK+kxeAnKSW1DISkloGQ1DIQkloGQlLLQEhqGQhJLQMhqWUgJLUMhKSWgZDUMhCSWgZCUstASGoZCEktAyGpZSAktQyEpJaBkNQyEJJaBkJSy0BIahkISS0DIallICS1DISkloGQ1DIQkloGQlLLQEhqGQhJLQMhqWUgJLUMhKTWVIFI8skkzyZ5Jsn9Sd6UZGuSx5IcHZZb1mtYSefXxIFIcjnwcWC5qt4FbAJuBfYBh6tqJ3B4WJe0AU37FmMz8NtJNgMXAq8Au4EDw/0HgJunfA1JczJxIKrqx8BngWPACeDnVfUt4LKqOjE85gRw6Vr7J9mbZCXJyurq6qRjSJqhad5ibGF0tHAV8FbgoiS3ne3+VbW/qparanlpaWnSMSTN0DRvMT4AvFRVq1X1KvAw8D7gZJJtAMPy1PRjSpqHaQJxDLg2yYVJAuwCjgCHgD3DY/YAB6cbUdK8bJ50x6p6PMlDwPeB14CngP3AxcCDSW5nFJFb1mNQSeffxIEAqKq7gbtP2/wrRkcTkjY4P0kpqWUgJLUMhKSWgZDUMhCSWgZCUstASGoZCEktAyGpZSAktQyEpJaBkNQyEJJaBkJSy0BIahkISS0DIallICS1DISkloGQ1DIQkloGQlLLQEhqGQhJLQMhqWUgJLUMhKSWgZDUMhCSWgZCUstASGoZCEktAyGpdcZAJPlSklNJnhnbtjXJY0mODsstY/fdmeTFJC8k+ZNZDS5p9s7mCOIrwE2nbdsHHK6qncDhYZ0kVwO3Ar837PN3STat27SSzqszBqKqvgf87LTNu4EDw+0DwM1j2x+oql9V1UvAi8B712dUSefbpOcgLquqEwDD8tJh++XAy2OPOz5sk7QBrfdJyqyxrdZ8YLI3yUqSldXV1XUeQ9J6mDQQJ5NsAxiWp4btx4Erxh63HXhlrSeoqv1VtVxVy0tLSxOOIWmWJg3EIWDPcHsPcHBs+61J3pjkKmAn8E/TjShpXlK15juA1x+Q3A+8H7gEOAncDXwdeBC4EjgG3FJVPxsefxfwF8BrwCeq6h/OOESyCvwS+OmE/x3n2yU463rbKHPCxp31d6vqnA7XzxiI8yXJSlUtz3uOs+Gs62+jzAm/WbP6SUpJLQMhqbVIgdg/7wHOgbOuv40yJ/wGzbow5yAkLZ5FOoKQtGAMhKTWQgQiyU3D5eEvJtk373n+R5IrknwnyZEkzya5Y9jeXu4+b0k2JXkqySPD+kLOmuQtSR5K8vzw871uEWdN8snh//0zSe5P8qZFmfN8fBXD3AMxXA7+t8AHgauBjw6XjS+C14BPVdU7gWuBjw2zrXm5+4K4Azgytr6os34B+GZVvQN4N6OZF2rWJJcDHweWq+pdwCZGX2ewKHN+hVl/FUNVzfUPcB3w6Nj6ncCd856rmfUgcCPwArBt2LYNeGHesw2zbB/+UtwAPDJsW7hZgTcDLzGcJB/bvlCz8vrVyVuBzcAjwB8v0pzADuCZM/0MT/+9Ah4FrjvT88/9CIINcol4kh3ANcDj9Je7z9vngU8Dvx7btoizvg1YBb48vB26J8lFLNisVfVj4LOMLic4Afy8qr7Fgs15mnX9KoZFCMRZXyI+L0kuBr7G6NqSX8x7nrUk+TBwqqqenPcsZ2Ez8B7gi1V1DaPrcBblrc//Gt6/7wauAt4KXJTktvlONbGJfs8WIRBnfYn4PCR5A6M43FdVDw+bu8vd5+l64CNJfgQ8ANyQ5Kss5qzHgeNV9fiw/hCjYCzarB8AXqqq1ap6FXgYeB+LN+e4qb+KYdwiBOIJYGeSq5JcwOhEyqE5zwRAkgD3Akeq6nNjd3WXu89NVd1ZVduragejn+G3q+o2FnPWnwAvJ3n7sGkX8ByLN+sx4NokFw5/F3YxOpm6aHOOW9+vYpjnSaCxEyYfAn4A/Ctw17znGZvrDxkdhv0z8PTw50PA7zA6GXh0WG6d96ynzf1+Xj9JuZCzAr8PrAw/268DWxZxVuBvgOeBZ4C/B964KHMC9zM6N/IqoyOE2/+/2YC7ht+xF4APns1r+FFrSa1FeIshaUEZCEktAyGpZSAktQyEpJaBkNQyEJJa/w0ISy7orPumHwAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "ds = XView3Polarizations(data_root)\n", + "sampler = RandomGeoSampler(ds, size=1024, length=5)\n", + "dl = DataLoader(ds, sampler=sampler, collate_fn=stack_samples)\n", + "\n", + "for sample in dl:\n", + " image = sample['image']\n", + " print(image.shape)\n", + " image = torch.squeeze(image)\n", + " plt.imshow(image, cmap='bone', vmin=-35, vmax=-5)" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "22a454fc14687a2143ab970e8915cf1cd36fe3c442d7a97f02ebf86977418cbe" + }, + "kernelspec": { + "display_name": "Python 3.7.11 64-bit ('overwatch': conda)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.11" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/environment.yml b/environment.yml index c103271b80a..54a2762a936 100644 --- a/environment.yml +++ b/environment.yml @@ -43,6 +43,6 @@ dependencies: - scipy>=0.9 - segmentation-models-pytorch>=0.2 - setuptools>=42 - - sphinx>=3 + - sphinx>=4 - timm>=0.2.1 - torchmetrics diff --git a/evaluate.py b/evaluate.py index fdedf648b05..b8a7fa27023 100755 --- a/evaluate.py +++ b/evaluate.py @@ -14,8 +14,8 @@ import torch from torchmetrics import Accuracy, IoU, Metric, MetricCollection -from torchgeo import _TASK_TO_MODULES_MAPPING as TASK_TO_MODULES_MAPPING from torchgeo.trainers import ClassificationTask, SemanticSegmentationTask +from train import TASK_TO_MODULES_MAPPING def set_up_parser() -> argparse.ArgumentParser: diff --git a/experiments/test_chesapeakecvpr_models.py b/experiments/test_chesapeakecvpr_models.py index 249dd20bc7d..38ab24fa842 100755 --- a/experiments/test_chesapeakecvpr_models.py +++ b/experiments/test_chesapeakecvpr_models.py @@ -11,7 +11,7 @@ import pytorch_lightning as pl import torch -from torchgeo.datasets import ChesapeakeCVPRDataModule +from torchgeo.datamodules import ChesapeakeCVPRDataModule from torchgeo.trainers.chesapeake import ChesapeakeCVPRSegmentationTask ALL_TEST_SPLITS = [["de-val"], ["pa-test"], ["ny-test"], ["pa-test", "ny-test"]] diff --git a/pyproject.toml b/pyproject.toml index cfb34131556..6f31702dce8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ exclude_lines = [ [tool.isort] profile = "black" -known_first_party = ["docs", "tests", "torchgeo"] +known_first_party = ["docs", "tests", "torchgeo", "train"] extend_skip = [".spack-env/", "data", "logs", "output"] skip_gitignore = true color_output = true @@ -73,16 +73,19 @@ strict_equality = true [tool.pydocstyle] convention = "google" -match_dir = "(datasets|models|samplers|torchgeo|trainers|transforms)" +match_dir = "(datamodules|datasets|losses|models|samplers|torchgeo|trainers|transforms)" [tool.pytest.ini_options] # Skip slow tests by default addopts = "-m 'not slow'" filterwarnings = [ - "ignore:.*Create unlinked descriptors is going to go away.*:DeprecationWarning", + "ignore:.*Create unlinked descriptors is going to go away:DeprecationWarning", # https://github.com/tensorflow/tensorboard/pull/5138 - "ignore:.*is a deprecated alias for the builtin.*:DeprecationWarning", - "ignore:.*Previous behaviour produces incorrect box coordinates.*:UserWarning", + "ignore:.*is a deprecated alias for the builtin:DeprecationWarning", + "ignore:Previous behaviour produces incorrect box coordinates:UserWarning", + "ignore:The dataloader, .*, does not have many workers which may be a bottleneck:UserWarning", + "ignore:Your `.*_dataloader` has `shuffle=True`:UserWarning", + "ignore:Trying to infer the `batch_size` from an ambiguous collection:UserWarning", ] markers = [ "slow: marks tests as slow", diff --git a/tests/data/README.md b/tests/data/README.md index 3884cbf382a..9e37b4923d4 100644 --- a/tests/data/README.md +++ b/tests/data/README.md @@ -16,7 +16,8 @@ ROOT = "data/landsat8" FILENAME = "LC08_L2SP_023032_20210622_20210629_02_T1_SR_B1.TIF" src = rasterio.open(os.path.join(ROOT, FILENAME)) -Z = np.arange(4, dtype=src.read().dtype).reshape(2, 2) +dtype = src.read().dtype +Z = np.random.randint(np.iinfo(dtype).max, size=(64, 64), dtype=dtype) dst = rasterio.open(FILENAME, "w", driver=src.driver, height=Z.shape[0], width=Z.shape[1], count=src.count, dtype=Z.dtype, crs=src.crs, transform=src.transform) for i in range(1, dst.count + 1): dst.write(Z, i) @@ -52,18 +53,28 @@ VisionDataset data can be created like so. ### RGB images ```python +import numpy as np from PIL import Image -img = Image.new("RGB", (1, 1)) +DTYPE = np.uint8 +SIZE = 64 + +arr = np.random.randint(np.iinfo(DTYPE).max, size=(SIZE, SIZE, 3), dtype=DTYPE) +img = Image.fromarray(arr) img.save("01.png") ``` ### Grayscale images ```python +import numpy as np from PIL import Image -img = Image.new("L", (1, 1)) +DTYPE = np.uint8 +SIZE = 64 + +arr = np.random.randint(np.iinfo(DTYPE).max, size=(SIZE, SIZE), dtype=DTYPE) +img = Image.fromarray(arr) img.save("02.jpg") ``` @@ -83,14 +94,15 @@ wavfile.write("01.wav", rate=22050, data=audio) import h5py import numpy as np -f = h5py.File("data.hdf5", "w") +DTYPE = np.uint8 +SIZE = 64 +NUM_CLASSES = 10 -num_classes = 10 -images = np.random.randint(low=0, high=255, size=(1, 1, 3)).astype(np.uint8) -masks = np.random.randint(low=0, high=num_classes, size=(1, 1)).astype(np.uint8) -f.create_dataset("images", data=images) -f.create_dataset("masks", data=masks) -f.close() +images = np.random.randint(np.iinfo(DTYPE).max, size=(SIZE, SIZE, 3), dtype=DTYPE) +masks = np.random.randint(NUM_CLASSES, size=(SIZE, SIZE), dtype=DTYPE) +with h5py.File("data.hdf5", "w") as f: + f.create_dataset("images", data=images) + f.create_dataset("masks", data=masks) ``` ### LAS Point Cloud files diff --git a/tests/data/chesapeake/cvpr/cvpr_chesapeake_landcover_prior_extension.zip b/tests/data/chesapeake/cvpr/cvpr_chesapeake_landcover_prior_extension.zip new file mode 100644 index 00000000000..4dc86ae73d6 Binary files /dev/null and b/tests/data/chesapeake/cvpr/cvpr_chesapeake_landcover_prior_extension.zip differ diff --git a/tests/data/chesapeake/cvpr/de_1m_2013_extended-debuffered-test_tiles/m_3807504_ne_18_1_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif b/tests/data/chesapeake/cvpr/de_1m_2013_extended-debuffered-test_tiles/m_3807504_ne_18_1_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif new file mode 100644 index 00000000000..57cd2d378d1 Binary files /dev/null and b/tests/data/chesapeake/cvpr/de_1m_2013_extended-debuffered-test_tiles/m_3807504_ne_18_1_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif differ diff --git a/tests/data/cowc_counting/COWC_Counting_Columbus_CSUAV_AFRL.tbz b/tests/data/cowc_counting/COWC_Counting_Columbus_CSUAV_AFRL.tbz index 92dfe682ca0..fc9b54ec201 100644 Binary files a/tests/data/cowc_counting/COWC_Counting_Columbus_CSUAV_AFRL.tbz and b/tests/data/cowc_counting/COWC_Counting_Columbus_CSUAV_AFRL.tbz differ diff --git a/tests/data/cowc_counting/COWC_Counting_Potsdam_ISPRS.tbz b/tests/data/cowc_counting/COWC_Counting_Potsdam_ISPRS.tbz index 1a5dd77ed16..29e7d83c121 100644 Binary files a/tests/data/cowc_counting/COWC_Counting_Potsdam_ISPRS.tbz and b/tests/data/cowc_counting/COWC_Counting_Potsdam_ISPRS.tbz differ diff --git a/tests/data/cowc_counting/COWC_Counting_Selwyn_LINZ.tbz b/tests/data/cowc_counting/COWC_Counting_Selwyn_LINZ.tbz index 258484e8e2a..9e897ebd2d9 100644 Binary files a/tests/data/cowc_counting/COWC_Counting_Selwyn_LINZ.tbz and b/tests/data/cowc_counting/COWC_Counting_Selwyn_LINZ.tbz differ diff --git a/tests/data/cowc_counting/COWC_Counting_Toronto_ISPRS.tbz b/tests/data/cowc_counting/COWC_Counting_Toronto_ISPRS.tbz index c3594d8f5b6..37884f8f77b 100644 Binary files a/tests/data/cowc_counting/COWC_Counting_Toronto_ISPRS.tbz and b/tests/data/cowc_counting/COWC_Counting_Toronto_ISPRS.tbz differ diff --git a/tests/data/cowc_counting/COWC_Counting_Utah_AGRC.tbz b/tests/data/cowc_counting/COWC_Counting_Utah_AGRC.tbz index 6eedeae9f68..c425862bf44 100644 Binary files a/tests/data/cowc_counting/COWC_Counting_Utah_AGRC.tbz and b/tests/data/cowc_counting/COWC_Counting_Utah_AGRC.tbz differ diff --git a/tests/data/cowc_counting/COWC_Counting_Vaihingen_ISPRS.tbz b/tests/data/cowc_counting/COWC_Counting_Vaihingen_ISPRS.tbz index 27058dd4d92..a1165019fe5 100644 Binary files a/tests/data/cowc_counting/COWC_Counting_Vaihingen_ISPRS.tbz and b/tests/data/cowc_counting/COWC_Counting_Vaihingen_ISPRS.tbz differ diff --git a/tests/data/cowc_counting/COWC_test_list_64_class.txt b/tests/data/cowc_counting/COWC_test_list_64_class.txt index 60ae5328dfd..c46f34bc023 100644 --- a/tests/data/cowc_counting/COWC_test_list_64_class.txt +++ b/tests/data/cowc_counting/COWC_test_list_64_class.txt @@ -1,6 +1,6 @@ -Toronto_ISPRS/test/fake_01.png 0 -Selwyn_LINZ/test/fake_03.png 1 -Potsdam_ISPRS/test/fake_05.png 2 -Vaihingen_ISPRS/test/fake_07.png 0 -Columbus_CSUAV_AFRL/test/fake_09.png 1 -Utah_AGRC/test/fake_11.png 11 +Toronto_ISPRS/test/fake_01.png 12 +Selwyn_LINZ/test/fake_03.png 8 +Potsdam_ISPRS/test/fake_05.png 12 +Vaihingen_ISPRS/test/fake_07.png 11 +Columbus_CSUAV_AFRL/test/fake_09.png 16 +Utah_AGRC/test/fake_11.png 4 diff --git a/tests/data/cowc_counting/COWC_test_list_64_class.txt.bz2 b/tests/data/cowc_counting/COWC_test_list_64_class.txt.bz2 index ac59f5a5a96..95e8e5aed23 100644 Binary files a/tests/data/cowc_counting/COWC_test_list_64_class.txt.bz2 and b/tests/data/cowc_counting/COWC_test_list_64_class.txt.bz2 differ diff --git a/tests/data/cowc_counting/COWC_train_list_64_class.txt b/tests/data/cowc_counting/COWC_train_list_64_class.txt index 60d877ce267..1fb05797843 100644 --- a/tests/data/cowc_counting/COWC_train_list_64_class.txt +++ b/tests/data/cowc_counting/COWC_train_list_64_class.txt @@ -1,12 +1,12 @@ -Toronto_ISPRS/train/fake_01.png 0 -Toronto_ISPRS/train/fake_02.png 3 -Selwyn_LINZ/train/fake_03.png 1 -Selwyn_LINZ/train/fake_04.png 0 -Potsdam_ISPRS/train/fake_05.png 2 -Potsdam_ISPRS/train/fake_06.png 2 -Vaihingen_ISPRS/train/fake_07.png 0 -Vaihingen_ISPRS/train/fake_08.png 0 -Columbus_CSUAV_AFRL/train/fake_09.png 1 -Columbus_CSUAV_AFRL/train/fake_10.png 1 -Utah_AGRC/train/fake_11.png 11 -Utah_AGRC/train/fake_12.png 12 +Toronto_ISPRS/train/fake_01.png 13 +Toronto_ISPRS/train/fake_02.png 1 +Selwyn_LINZ/train/fake_03.png 16 +Selwyn_LINZ/train/fake_04.png 15 +Potsdam_ISPRS/train/fake_05.png 9 +Potsdam_ISPRS/train/fake_06.png 15 +Vaihingen_ISPRS/train/fake_07.png 18 +Vaihingen_ISPRS/train/fake_08.png 6 +Columbus_CSUAV_AFRL/train/fake_09.png 4 +Columbus_CSUAV_AFRL/train/fake_10.png 9 +Utah_AGRC/train/fake_11.png 3 +Utah_AGRC/train/fake_12.png 19 diff --git a/tests/data/cowc_counting/COWC_train_list_64_class.txt.bz2 b/tests/data/cowc_counting/COWC_train_list_64_class.txt.bz2 index 57375a57249..15f5e2f6eb3 100644 Binary files a/tests/data/cowc_counting/COWC_train_list_64_class.txt.bz2 and b/tests/data/cowc_counting/COWC_train_list_64_class.txt.bz2 differ diff --git a/tests/data/cowc_counting/Columbus_CSUAV_AFRL/test/fake_09.png b/tests/data/cowc_counting/Columbus_CSUAV_AFRL/test/fake_09.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_counting/Columbus_CSUAV_AFRL/test/fake_09.png and b/tests/data/cowc_counting/Columbus_CSUAV_AFRL/test/fake_09.png differ diff --git a/tests/data/cowc_counting/Columbus_CSUAV_AFRL/train/fake_09.png b/tests/data/cowc_counting/Columbus_CSUAV_AFRL/train/fake_09.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_counting/Columbus_CSUAV_AFRL/train/fake_09.png and b/tests/data/cowc_counting/Columbus_CSUAV_AFRL/train/fake_09.png differ diff --git a/tests/data/cowc_counting/Columbus_CSUAV_AFRL/train/fake_10.png b/tests/data/cowc_counting/Columbus_CSUAV_AFRL/train/fake_10.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_counting/Columbus_CSUAV_AFRL/train/fake_10.png and b/tests/data/cowc_counting/Columbus_CSUAV_AFRL/train/fake_10.png differ diff --git a/tests/data/cowc_counting/Potsdam_ISPRS/test/fake_05.png b/tests/data/cowc_counting/Potsdam_ISPRS/test/fake_05.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_counting/Potsdam_ISPRS/test/fake_05.png and b/tests/data/cowc_counting/Potsdam_ISPRS/test/fake_05.png differ diff --git a/tests/data/cowc_counting/Potsdam_ISPRS/train/fake_05.png b/tests/data/cowc_counting/Potsdam_ISPRS/train/fake_05.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_counting/Potsdam_ISPRS/train/fake_05.png and b/tests/data/cowc_counting/Potsdam_ISPRS/train/fake_05.png differ diff --git a/tests/data/cowc_counting/Potsdam_ISPRS/train/fake_06.png b/tests/data/cowc_counting/Potsdam_ISPRS/train/fake_06.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_counting/Potsdam_ISPRS/train/fake_06.png and b/tests/data/cowc_counting/Potsdam_ISPRS/train/fake_06.png differ diff --git a/tests/data/cowc_counting/Selwyn_LINZ/test/fake_03.png b/tests/data/cowc_counting/Selwyn_LINZ/test/fake_03.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_counting/Selwyn_LINZ/test/fake_03.png and b/tests/data/cowc_counting/Selwyn_LINZ/test/fake_03.png differ diff --git a/tests/data/cowc_counting/Selwyn_LINZ/train/fake_03.png b/tests/data/cowc_counting/Selwyn_LINZ/train/fake_03.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_counting/Selwyn_LINZ/train/fake_03.png and b/tests/data/cowc_counting/Selwyn_LINZ/train/fake_03.png differ diff --git a/tests/data/cowc_counting/Selwyn_LINZ/train/fake_04.png b/tests/data/cowc_counting/Selwyn_LINZ/train/fake_04.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_counting/Selwyn_LINZ/train/fake_04.png and b/tests/data/cowc_counting/Selwyn_LINZ/train/fake_04.png differ diff --git a/tests/data/cowc_counting/Toronto_ISPRS/test/fake_01.png b/tests/data/cowc_counting/Toronto_ISPRS/test/fake_01.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_counting/Toronto_ISPRS/test/fake_01.png and b/tests/data/cowc_counting/Toronto_ISPRS/test/fake_01.png differ diff --git a/tests/data/cowc_counting/Toronto_ISPRS/train/fake_01.png b/tests/data/cowc_counting/Toronto_ISPRS/train/fake_01.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_counting/Toronto_ISPRS/train/fake_01.png and b/tests/data/cowc_counting/Toronto_ISPRS/train/fake_01.png differ diff --git a/tests/data/cowc_counting/Toronto_ISPRS/train/fake_02.png b/tests/data/cowc_counting/Toronto_ISPRS/train/fake_02.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_counting/Toronto_ISPRS/train/fake_02.png and b/tests/data/cowc_counting/Toronto_ISPRS/train/fake_02.png differ diff --git a/tests/data/cowc_counting/Utah_AGRC/test/fake_11.png b/tests/data/cowc_counting/Utah_AGRC/test/fake_11.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_counting/Utah_AGRC/test/fake_11.png and b/tests/data/cowc_counting/Utah_AGRC/test/fake_11.png differ diff --git a/tests/data/cowc_counting/Utah_AGRC/train/fake_11.png b/tests/data/cowc_counting/Utah_AGRC/train/fake_11.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_counting/Utah_AGRC/train/fake_11.png and b/tests/data/cowc_counting/Utah_AGRC/train/fake_11.png differ diff --git a/tests/data/cowc_counting/Utah_AGRC/train/fake_12.png b/tests/data/cowc_counting/Utah_AGRC/train/fake_12.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_counting/Utah_AGRC/train/fake_12.png and b/tests/data/cowc_counting/Utah_AGRC/train/fake_12.png differ diff --git a/tests/data/cowc_counting/Vaihingen_ISPRS/test/fake_07.png b/tests/data/cowc_counting/Vaihingen_ISPRS/test/fake_07.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_counting/Vaihingen_ISPRS/test/fake_07.png and b/tests/data/cowc_counting/Vaihingen_ISPRS/test/fake_07.png differ diff --git a/tests/data/cowc_counting/Vaihingen_ISPRS/train/fake_07.png b/tests/data/cowc_counting/Vaihingen_ISPRS/train/fake_07.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_counting/Vaihingen_ISPRS/train/fake_07.png and b/tests/data/cowc_counting/Vaihingen_ISPRS/train/fake_07.png differ diff --git a/tests/data/cowc_counting/Vaihingen_ISPRS/train/fake_08.png b/tests/data/cowc_counting/Vaihingen_ISPRS/train/fake_08.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_counting/Vaihingen_ISPRS/train/fake_08.png and b/tests/data/cowc_counting/Vaihingen_ISPRS/train/fake_08.png differ diff --git a/tests/data/cowc_counting/data.py b/tests/data/cowc_counting/data.py new file mode 100755 index 00000000000..b75bec2e1e0 --- /dev/null +++ b/tests/data/cowc_counting/data.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import bz2 +import csv +import hashlib +import glob +import os +import random +import shutil + +from PIL import Image + + +SIZE = 64 # image width/height +STOP = 20 # range of values for labels +PREFIX = "Counting" +SUFFIX = "64_class" + +random.seed(0) + +sites = [ + "Toronto_ISPRS", + "Selwyn_LINZ", + "Potsdam_ISPRS", + "Vaihingen_ISPRS", + "Columbus_CSUAV_AFRL", + "Utah_AGRC", +] + +# Remove old data +for filename in glob.glob("COWC_*"): + os.remove(filename) +for site in sites: + if os.path.exists(site): + shutil.rmtree(site) + +i = 1 +data_list = {"train": [], "test": []} +image_md5s = [] +for site in sites: + # Create images + for split in ["test", "train", "train"]: + directory = os.path.join(site, split) + os.makedirs(directory, exist_ok=True) + filename = os.path.join(directory, f"fake_{i:02}.png") + + img = Image.new("RGB", (SIZE, SIZE)) + img.save(filename) + + data_list[split].append((filename, random.randrange(STOP))) + + if split == "train": + i += 1 + + # Compress images + filename = f"COWC_{PREFIX}_{site}.tbz" + bad_filename = shutil.make_archive(filename.replace(".tbz", ""), "bztar", ".", site) + os.rename(bad_filename, filename) + + # Compute checksums + with open(filename, "rb") as f: + image_md5s.append(hashlib.md5(f.read()).hexdigest()) + +label_md5s = [] +for split in ["train", "test"]: + # Create labels + filename = f"COWC_{split}_list_{SUFFIX}.txt" + with open(filename, "w", newline="") as csvfile: + csvwriter = csv.writer(csvfile, delimiter=" ") + csvwriter.writerows(data_list[split]) + + # Compress labels + with open(filename, "rb") as src: + with bz2.open(filename + ".bz2", "wb") as dst: + dst.write(src.read()) + + # Compute checksums + with open(filename + ".bz2", "rb") as f: + label_md5s.append(hashlib.md5(f.read()).hexdigest()) + +md5s = label_md5s + image_md5s +for md5 in md5s: + print(repr(md5) + ",") diff --git a/tests/data/cowc_detection/COWC_Detection_Columbus_CSUAV_AFRL.tbz b/tests/data/cowc_detection/COWC_Detection_Columbus_CSUAV_AFRL.tbz index dff3032bd06..f42c6b4485d 100644 Binary files a/tests/data/cowc_detection/COWC_Detection_Columbus_CSUAV_AFRL.tbz and b/tests/data/cowc_detection/COWC_Detection_Columbus_CSUAV_AFRL.tbz differ diff --git a/tests/data/cowc_detection/COWC_Detection_Potsdam_ISPRS.tbz b/tests/data/cowc_detection/COWC_Detection_Potsdam_ISPRS.tbz index ed721194044..c60c8aa418d 100644 Binary files a/tests/data/cowc_detection/COWC_Detection_Potsdam_ISPRS.tbz and b/tests/data/cowc_detection/COWC_Detection_Potsdam_ISPRS.tbz differ diff --git a/tests/data/cowc_detection/COWC_Detection_Selwyn_LINZ.tbz b/tests/data/cowc_detection/COWC_Detection_Selwyn_LINZ.tbz index 9be43cea8e8..b3ace699c87 100644 Binary files a/tests/data/cowc_detection/COWC_Detection_Selwyn_LINZ.tbz and b/tests/data/cowc_detection/COWC_Detection_Selwyn_LINZ.tbz differ diff --git a/tests/data/cowc_detection/COWC_Detection_Toronto_ISPRS.tbz b/tests/data/cowc_detection/COWC_Detection_Toronto_ISPRS.tbz index 092b9ffd318..1ccc7f304a0 100644 Binary files a/tests/data/cowc_detection/COWC_Detection_Toronto_ISPRS.tbz and b/tests/data/cowc_detection/COWC_Detection_Toronto_ISPRS.tbz differ diff --git a/tests/data/cowc_detection/COWC_Detection_Utah_AGRC.tbz b/tests/data/cowc_detection/COWC_Detection_Utah_AGRC.tbz index 57d2faff089..9385f26f0cc 100644 Binary files a/tests/data/cowc_detection/COWC_Detection_Utah_AGRC.tbz and b/tests/data/cowc_detection/COWC_Detection_Utah_AGRC.tbz differ diff --git a/tests/data/cowc_detection/COWC_Detection_Vaihingen_ISPRS.tbz b/tests/data/cowc_detection/COWC_Detection_Vaihingen_ISPRS.tbz index 76c5ee0f37e..61b737c64a8 100644 Binary files a/tests/data/cowc_detection/COWC_Detection_Vaihingen_ISPRS.tbz and b/tests/data/cowc_detection/COWC_Detection_Vaihingen_ISPRS.tbz differ diff --git a/tests/data/cowc_detection/COWC_test_list_detection.txt b/tests/data/cowc_detection/COWC_test_list_detection.txt index ec3ddbab651..583e5231d41 100644 --- a/tests/data/cowc_detection/COWC_test_list_detection.txt +++ b/tests/data/cowc_detection/COWC_test_list_detection.txt @@ -1,6 +1,6 @@ -Toronto_ISPRS/test/fake_01.png 0 -Selwyn_LINZ/test/fake_03.png 1 -Potsdam_ISPRS/test/fake_05.png 1 -Vaihingen_ISPRS/test/fake_07.png 0 -Columbus_CSUAV_AFRL/test/fake_09.png 1 -Utah_AGRC/test/fake_11.png 1 +Toronto_ISPRS/test/fake_01.png 1 +Selwyn_LINZ/test/fake_03.png 1 +Potsdam_ISPRS/test/fake_05.png 1 +Vaihingen_ISPRS/test/fake_07.png 0 +Columbus_CSUAV_AFRL/test/fake_09.png 0 +Utah_AGRC/test/fake_11.png 0 diff --git a/tests/data/cowc_detection/COWC_test_list_detection.txt.bz2 b/tests/data/cowc_detection/COWC_test_list_detection.txt.bz2 index 9d82748e4e9..45b56567b03 100644 Binary files a/tests/data/cowc_detection/COWC_test_list_detection.txt.bz2 and b/tests/data/cowc_detection/COWC_test_list_detection.txt.bz2 differ diff --git a/tests/data/cowc_detection/COWC_train_list_detection.txt b/tests/data/cowc_detection/COWC_train_list_detection.txt index dfc082630c8..2eb48c8fa3b 100644 --- a/tests/data/cowc_detection/COWC_train_list_detection.txt +++ b/tests/data/cowc_detection/COWC_train_list_detection.txt @@ -1,12 +1,12 @@ -Toronto_ISPRS/train/fake_01.png 0 -Toronto_ISPRS/train/fake_02.png 1 -Selwyn_LINZ/train/fake_03.png 1 -Selwyn_LINZ/train/fake_04.png 0 -Potsdam_ISPRS/train/fake_05.png 1 -Potsdam_ISPRS/train/fake_06.png 1 -Vaihingen_ISPRS/train/fake_07.png 0 -Vaihingen_ISPRS/train/fake_08.png 0 -Columbus_CSUAV_AFRL/train/fake_09.png 1 -Columbus_CSUAV_AFRL/train/fake_10.png 1 -Utah_AGRC/train/fake_11.png 1 -Utah_AGRC/train/fake_12.png 1 +Toronto_ISPRS/train/fake_01.png 1 +Toronto_ISPRS/train/fake_02.png 0 +Selwyn_LINZ/train/fake_03.png 1 +Selwyn_LINZ/train/fake_04.png 1 +Potsdam_ISPRS/train/fake_05.png 1 +Potsdam_ISPRS/train/fake_06.png 1 +Vaihingen_ISPRS/train/fake_07.png 0 +Vaihingen_ISPRS/train/fake_08.png 1 +Columbus_CSUAV_AFRL/train/fake_09.png 0 +Columbus_CSUAV_AFRL/train/fake_10.png 1 +Utah_AGRC/train/fake_11.png 1 +Utah_AGRC/train/fake_12.png 0 diff --git a/tests/data/cowc_detection/COWC_train_list_detection.txt.bz2 b/tests/data/cowc_detection/COWC_train_list_detection.txt.bz2 index 4484776bbbb..0dcfe35b8da 100644 Binary files a/tests/data/cowc_detection/COWC_train_list_detection.txt.bz2 and b/tests/data/cowc_detection/COWC_train_list_detection.txt.bz2 differ diff --git a/tests/data/cowc_detection/Columbus_CSUAV_AFRL/test/fake_09.png b/tests/data/cowc_detection/Columbus_CSUAV_AFRL/test/fake_09.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_detection/Columbus_CSUAV_AFRL/test/fake_09.png and b/tests/data/cowc_detection/Columbus_CSUAV_AFRL/test/fake_09.png differ diff --git a/tests/data/cowc_detection/Columbus_CSUAV_AFRL/train/fake_09.png b/tests/data/cowc_detection/Columbus_CSUAV_AFRL/train/fake_09.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_detection/Columbus_CSUAV_AFRL/train/fake_09.png and b/tests/data/cowc_detection/Columbus_CSUAV_AFRL/train/fake_09.png differ diff --git a/tests/data/cowc_detection/Columbus_CSUAV_AFRL/train/fake_10.png b/tests/data/cowc_detection/Columbus_CSUAV_AFRL/train/fake_10.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_detection/Columbus_CSUAV_AFRL/train/fake_10.png and b/tests/data/cowc_detection/Columbus_CSUAV_AFRL/train/fake_10.png differ diff --git a/tests/data/cowc_detection/Potsdam_ISPRS/test/fake_05.png b/tests/data/cowc_detection/Potsdam_ISPRS/test/fake_05.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_detection/Potsdam_ISPRS/test/fake_05.png and b/tests/data/cowc_detection/Potsdam_ISPRS/test/fake_05.png differ diff --git a/tests/data/cowc_detection/Potsdam_ISPRS/train/fake_05.png b/tests/data/cowc_detection/Potsdam_ISPRS/train/fake_05.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_detection/Potsdam_ISPRS/train/fake_05.png and b/tests/data/cowc_detection/Potsdam_ISPRS/train/fake_05.png differ diff --git a/tests/data/cowc_detection/Potsdam_ISPRS/train/fake_06.png b/tests/data/cowc_detection/Potsdam_ISPRS/train/fake_06.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_detection/Potsdam_ISPRS/train/fake_06.png and b/tests/data/cowc_detection/Potsdam_ISPRS/train/fake_06.png differ diff --git a/tests/data/cowc_detection/Selwyn_LINZ/test/fake_03.png b/tests/data/cowc_detection/Selwyn_LINZ/test/fake_03.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_detection/Selwyn_LINZ/test/fake_03.png and b/tests/data/cowc_detection/Selwyn_LINZ/test/fake_03.png differ diff --git a/tests/data/cowc_detection/Selwyn_LINZ/train/fake_03.png b/tests/data/cowc_detection/Selwyn_LINZ/train/fake_03.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_detection/Selwyn_LINZ/train/fake_03.png and b/tests/data/cowc_detection/Selwyn_LINZ/train/fake_03.png differ diff --git a/tests/data/cowc_detection/Selwyn_LINZ/train/fake_04.png b/tests/data/cowc_detection/Selwyn_LINZ/train/fake_04.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_detection/Selwyn_LINZ/train/fake_04.png and b/tests/data/cowc_detection/Selwyn_LINZ/train/fake_04.png differ diff --git a/tests/data/cowc_detection/Toronto_ISPRS/test/fake_01.png b/tests/data/cowc_detection/Toronto_ISPRS/test/fake_01.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_detection/Toronto_ISPRS/test/fake_01.png and b/tests/data/cowc_detection/Toronto_ISPRS/test/fake_01.png differ diff --git a/tests/data/cowc_detection/Toronto_ISPRS/train/fake_01.png b/tests/data/cowc_detection/Toronto_ISPRS/train/fake_01.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_detection/Toronto_ISPRS/train/fake_01.png and b/tests/data/cowc_detection/Toronto_ISPRS/train/fake_01.png differ diff --git a/tests/data/cowc_detection/Toronto_ISPRS/train/fake_02.png b/tests/data/cowc_detection/Toronto_ISPRS/train/fake_02.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_detection/Toronto_ISPRS/train/fake_02.png and b/tests/data/cowc_detection/Toronto_ISPRS/train/fake_02.png differ diff --git a/tests/data/cowc_detection/Utah_AGRC/test/fake_11.png b/tests/data/cowc_detection/Utah_AGRC/test/fake_11.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_detection/Utah_AGRC/test/fake_11.png and b/tests/data/cowc_detection/Utah_AGRC/test/fake_11.png differ diff --git a/tests/data/cowc_detection/Utah_AGRC/train/fake_11.png b/tests/data/cowc_detection/Utah_AGRC/train/fake_11.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_detection/Utah_AGRC/train/fake_11.png and b/tests/data/cowc_detection/Utah_AGRC/train/fake_11.png differ diff --git a/tests/data/cowc_detection/Utah_AGRC/train/fake_12.png b/tests/data/cowc_detection/Utah_AGRC/train/fake_12.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_detection/Utah_AGRC/train/fake_12.png and b/tests/data/cowc_detection/Utah_AGRC/train/fake_12.png differ diff --git a/tests/data/cowc_detection/Vaihingen_ISPRS/test/fake_07.png b/tests/data/cowc_detection/Vaihingen_ISPRS/test/fake_07.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_detection/Vaihingen_ISPRS/test/fake_07.png and b/tests/data/cowc_detection/Vaihingen_ISPRS/test/fake_07.png differ diff --git a/tests/data/cowc_detection/Vaihingen_ISPRS/train/fake_07.png b/tests/data/cowc_detection/Vaihingen_ISPRS/train/fake_07.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_detection/Vaihingen_ISPRS/train/fake_07.png and b/tests/data/cowc_detection/Vaihingen_ISPRS/train/fake_07.png differ diff --git a/tests/data/cowc_detection/Vaihingen_ISPRS/train/fake_08.png b/tests/data/cowc_detection/Vaihingen_ISPRS/train/fake_08.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/cowc_detection/Vaihingen_ISPRS/train/fake_08.png and b/tests/data/cowc_detection/Vaihingen_ISPRS/train/fake_08.png differ diff --git a/tests/data/cowc_detection/data.py b/tests/data/cowc_detection/data.py new file mode 100755 index 00000000000..f43a57e6439 --- /dev/null +++ b/tests/data/cowc_detection/data.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import bz2 +import csv +import hashlib +import glob +import os +import random +import shutil + +from PIL import Image + + +SIZE = 64 # image width/height +STOP = 2 # range of values for labels +PREFIX = "Detection" +SUFFIX = "detection" + +random.seed(0) + +sites = [ + "Toronto_ISPRS", + "Selwyn_LINZ", + "Potsdam_ISPRS", + "Vaihingen_ISPRS", + "Columbus_CSUAV_AFRL", + "Utah_AGRC", +] + +# Remove old data +for filename in glob.glob("COWC_*"): + os.remove(filename) +for site in sites: + if os.path.exists(site): + shutil.rmtree(site) + +i = 1 +data_list = {"train": [], "test": []} +image_md5s = [] +for site in sites: + # Create images + for split in ["test", "train", "train"]: + directory = os.path.join(site, split) + os.makedirs(directory, exist_ok=True) + filename = os.path.join(directory, f"fake_{i:02}.png") + + img = Image.new("RGB", (SIZE, SIZE)) + img.save(filename) + + data_list[split].append((filename, random.randrange(STOP))) + + if split == "train": + i += 1 + + # Compress images + filename = f"COWC_{PREFIX}_{site}.tbz" + bad_filename = shutil.make_archive(filename.replace(".tbz", ""), "bztar", ".", site) + os.rename(bad_filename, filename) + + # Compute checksums + with open(filename, "rb") as f: + image_md5s.append(hashlib.md5(f.read()).hexdigest()) + +label_md5s = [] +for split in ["train", "test"]: + # Create labels + filename = f"COWC_{split}_list_{SUFFIX}.txt" + with open(filename, "w", newline="") as csvfile: + csvwriter = csv.writer(csvfile, delimiter=" ") + csvwriter.writerows(data_list[split]) + + # Compress labels + with open(filename, "rb") as src: + with bz2.open(filename + ".bz2", "wb") as dst: + dst.write(src.read()) + + # Compute checksums + with open(filename + ".bz2", "rb") as f: + label_md5s.append(hashlib.md5(f.read()).hexdigest()) + +md5s = label_md5s + image_md5s +for md5 in md5s: + print(repr(md5) + ",") diff --git a/tests/data/etci2021/data.py b/tests/data/etci2021/data.py new file mode 100755 index 00000000000..03d14d640b4 --- /dev/null +++ b/tests/data/etci2021/data.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import hashlib +import os +import shutil + +from PIL import Image + +SIZE = 64 # image width/height + +metadatas = [ + { + "filename": "train.zip", + "directory": "train", + "subdirs": ["nebraska_20170108t002112", "bangladesh_20170314t115609"], + }, + { + "filename": "val_with_ref_labels.zip", + "directory": "test", + "subdirs": ["florence_20180510t231343", "florence_20180522t231344"], + }, + { + "filename": "test_without_ref_labels.zip", + "directory": "test_internal", + "subdirs": ["redrivernorth_20190104t002247", "redrivernorth_20190116t002247"], + }, +] + +tiles = ["vh", "vv", "water_body_label", "flood_label"] + +for metadata in metadatas: + filename = metadata["filename"] + directory = metadata["directory"] + + # Remove old data + if os.path.exists(filename): + os.remove(filename) + if os.path.exists(directory): + shutil.rmtree(directory) + + # Create images + for subdir in metadata["subdirs"]: + for tile in tiles: + if directory == "test_internal" and tile == "flood_label": + continue + + fn = f"{subdir}_x-0_y-0" + if tile in ["vh", "vv"]: + fn += f"_{tile}" + fn += ".png" + fd = os.path.join(directory, subdir, "tiles", tile) + os.makedirs(fd) + + img = Image.new("RGB", (SIZE, SIZE)) + img.save(os.path.join(fd, fn)) + + # Compress data + shutil.make_archive(filename.replace(".zip", ""), "zip", ".", directory) + + # Compute checksums + with open(filename, "rb") as f: + md5 = hashlib.md5(f.read()).hexdigest() + print(repr(filename) + ":", repr(md5) + ",") diff --git a/tests/data/etci2021/test/florence_20180510t231343/tiles/flood_label/florence_20180510t231343_x-0_y-0.png b/tests/data/etci2021/test/florence_20180510t231343/tiles/flood_label/florence_20180510t231343_x-0_y-0.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/etci2021/test/florence_20180510t231343/tiles/flood_label/florence_20180510t231343_x-0_y-0.png and b/tests/data/etci2021/test/florence_20180510t231343/tiles/flood_label/florence_20180510t231343_x-0_y-0.png differ diff --git a/tests/data/etci2021/test/florence_20180510t231343/tiles/vh/florence_20180510t231343_x-0_y-0_vh.png b/tests/data/etci2021/test/florence_20180510t231343/tiles/vh/florence_20180510t231343_x-0_y-0_vh.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/etci2021/test/florence_20180510t231343/tiles/vh/florence_20180510t231343_x-0_y-0_vh.png and b/tests/data/etci2021/test/florence_20180510t231343/tiles/vh/florence_20180510t231343_x-0_y-0_vh.png differ diff --git a/tests/data/etci2021/test/florence_20180510t231343/tiles/vv/florence_20180510t231343_x-0_y-0_vv.png b/tests/data/etci2021/test/florence_20180510t231343/tiles/vv/florence_20180510t231343_x-0_y-0_vv.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/etci2021/test/florence_20180510t231343/tiles/vv/florence_20180510t231343_x-0_y-0_vv.png and b/tests/data/etci2021/test/florence_20180510t231343/tiles/vv/florence_20180510t231343_x-0_y-0_vv.png differ diff --git a/tests/data/etci2021/test/florence_20180510t231343/tiles/water_body_label/florence_20180510t231343_x-0_y-0.png b/tests/data/etci2021/test/florence_20180510t231343/tiles/water_body_label/florence_20180510t231343_x-0_y-0.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/etci2021/test/florence_20180510t231343/tiles/water_body_label/florence_20180510t231343_x-0_y-0.png and b/tests/data/etci2021/test/florence_20180510t231343/tiles/water_body_label/florence_20180510t231343_x-0_y-0.png differ diff --git a/tests/data/etci2021/test/florence_20180522t231344/tiles/flood_label/florence_20180522t231344_x-0_y-0.png b/tests/data/etci2021/test/florence_20180522t231344/tiles/flood_label/florence_20180522t231344_x-0_y-0.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/etci2021/test/florence_20180522t231344/tiles/flood_label/florence_20180522t231344_x-0_y-0.png and b/tests/data/etci2021/test/florence_20180522t231344/tiles/flood_label/florence_20180522t231344_x-0_y-0.png differ diff --git a/tests/data/etci2021/test/florence_20180522t231344/tiles/vh/florence_20180522t231344_x-0_y-0_vh.png b/tests/data/etci2021/test/florence_20180522t231344/tiles/vh/florence_20180522t231344_x-0_y-0_vh.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/etci2021/test/florence_20180522t231344/tiles/vh/florence_20180522t231344_x-0_y-0_vh.png and b/tests/data/etci2021/test/florence_20180522t231344/tiles/vh/florence_20180522t231344_x-0_y-0_vh.png differ diff --git a/tests/data/etci2021/test/florence_20180522t231344/tiles/vv/florence_20180522t231344_x-0_y-0_vv.png b/tests/data/etci2021/test/florence_20180522t231344/tiles/vv/florence_20180522t231344_x-0_y-0_vv.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/etci2021/test/florence_20180522t231344/tiles/vv/florence_20180522t231344_x-0_y-0_vv.png and b/tests/data/etci2021/test/florence_20180522t231344/tiles/vv/florence_20180522t231344_x-0_y-0_vv.png differ diff --git a/tests/data/etci2021/test/florence_20180522t231344/tiles/water_body_label/florence_20180522t231344_x-0_y-0.png b/tests/data/etci2021/test/florence_20180522t231344/tiles/water_body_label/florence_20180522t231344_x-0_y-0.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/etci2021/test/florence_20180522t231344/tiles/water_body_label/florence_20180522t231344_x-0_y-0.png and b/tests/data/etci2021/test/florence_20180522t231344/tiles/water_body_label/florence_20180522t231344_x-0_y-0.png differ diff --git a/tests/data/etci2021/test_internal/redrivernorth_20190104t002247/tiles/vh/redrivernorth_20190104t002247_x-0_y-0_vh.png b/tests/data/etci2021/test_internal/redrivernorth_20190104t002247/tiles/vh/redrivernorth_20190104t002247_x-0_y-0_vh.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/etci2021/test_internal/redrivernorth_20190104t002247/tiles/vh/redrivernorth_20190104t002247_x-0_y-0_vh.png and b/tests/data/etci2021/test_internal/redrivernorth_20190104t002247/tiles/vh/redrivernorth_20190104t002247_x-0_y-0_vh.png differ diff --git a/tests/data/etci2021/test_internal/redrivernorth_20190104t002247/tiles/vv/redrivernorth_20190104t002247_x-0_y-0_vv.png b/tests/data/etci2021/test_internal/redrivernorth_20190104t002247/tiles/vv/redrivernorth_20190104t002247_x-0_y-0_vv.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/etci2021/test_internal/redrivernorth_20190104t002247/tiles/vv/redrivernorth_20190104t002247_x-0_y-0_vv.png and b/tests/data/etci2021/test_internal/redrivernorth_20190104t002247/tiles/vv/redrivernorth_20190104t002247_x-0_y-0_vv.png differ diff --git a/tests/data/etci2021/test_internal/redrivernorth_20190104t002247/tiles/water_body_label/redrivernorth_20190104t002247_x-0_y-0.png b/tests/data/etci2021/test_internal/redrivernorth_20190104t002247/tiles/water_body_label/redrivernorth_20190104t002247_x-0_y-0.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/etci2021/test_internal/redrivernorth_20190104t002247/tiles/water_body_label/redrivernorth_20190104t002247_x-0_y-0.png and b/tests/data/etci2021/test_internal/redrivernorth_20190104t002247/tiles/water_body_label/redrivernorth_20190104t002247_x-0_y-0.png differ diff --git a/tests/data/etci2021/test_internal/redrivernorth_20190116t002247/tiles/vh/redrivernorth_20190116t002247_x-0_y-0_vh.png b/tests/data/etci2021/test_internal/redrivernorth_20190116t002247/tiles/vh/redrivernorth_20190116t002247_x-0_y-0_vh.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/etci2021/test_internal/redrivernorth_20190116t002247/tiles/vh/redrivernorth_20190116t002247_x-0_y-0_vh.png and b/tests/data/etci2021/test_internal/redrivernorth_20190116t002247/tiles/vh/redrivernorth_20190116t002247_x-0_y-0_vh.png differ diff --git a/tests/data/etci2021/test_internal/redrivernorth_20190116t002247/tiles/vv/redrivernorth_20190116t002247_x-0_y-0_vv.png b/tests/data/etci2021/test_internal/redrivernorth_20190116t002247/tiles/vv/redrivernorth_20190116t002247_x-0_y-0_vv.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/etci2021/test_internal/redrivernorth_20190116t002247/tiles/vv/redrivernorth_20190116t002247_x-0_y-0_vv.png and b/tests/data/etci2021/test_internal/redrivernorth_20190116t002247/tiles/vv/redrivernorth_20190116t002247_x-0_y-0_vv.png differ diff --git a/tests/data/etci2021/test_internal/redrivernorth_20190116t002247/tiles/water_body_label/redrivernorth_20190116t002247_x-0_y-0.png b/tests/data/etci2021/test_internal/redrivernorth_20190116t002247/tiles/water_body_label/redrivernorth_20190116t002247_x-0_y-0.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/etci2021/test_internal/redrivernorth_20190116t002247/tiles/water_body_label/redrivernorth_20190116t002247_x-0_y-0.png and b/tests/data/etci2021/test_internal/redrivernorth_20190116t002247/tiles/water_body_label/redrivernorth_20190116t002247_x-0_y-0.png differ diff --git a/tests/data/etci2021/test_without_ref_labels.zip b/tests/data/etci2021/test_without_ref_labels.zip index 0f94a1e4ce5..34e062dbc3e 100644 Binary files a/tests/data/etci2021/test_without_ref_labels.zip and b/tests/data/etci2021/test_without_ref_labels.zip differ diff --git a/tests/data/etci2021/train.zip b/tests/data/etci2021/train.zip index 5cba717d7a1..b24e36421b2 100644 Binary files a/tests/data/etci2021/train.zip and b/tests/data/etci2021/train.zip differ diff --git a/tests/data/etci2021/train/bangladesh_20170314t115609/tiles/flood_label/bangladesh_20170314t115609_x-0_y-0.png b/tests/data/etci2021/train/bangladesh_20170314t115609/tiles/flood_label/bangladesh_20170314t115609_x-0_y-0.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/etci2021/train/bangladesh_20170314t115609/tiles/flood_label/bangladesh_20170314t115609_x-0_y-0.png and b/tests/data/etci2021/train/bangladesh_20170314t115609/tiles/flood_label/bangladesh_20170314t115609_x-0_y-0.png differ diff --git a/tests/data/etci2021/train/bangladesh_20170314t115609/tiles/vh/bangladesh_20170314t115609_x-0_y-0_vh.png b/tests/data/etci2021/train/bangladesh_20170314t115609/tiles/vh/bangladesh_20170314t115609_x-0_y-0_vh.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/etci2021/train/bangladesh_20170314t115609/tiles/vh/bangladesh_20170314t115609_x-0_y-0_vh.png and b/tests/data/etci2021/train/bangladesh_20170314t115609/tiles/vh/bangladesh_20170314t115609_x-0_y-0_vh.png differ diff --git a/tests/data/etci2021/train/bangladesh_20170314t115609/tiles/vv/bangladesh_20170314t115609_x-0_y-0_vv.png b/tests/data/etci2021/train/bangladesh_20170314t115609/tiles/vv/bangladesh_20170314t115609_x-0_y-0_vv.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/etci2021/train/bangladesh_20170314t115609/tiles/vv/bangladesh_20170314t115609_x-0_y-0_vv.png and b/tests/data/etci2021/train/bangladesh_20170314t115609/tiles/vv/bangladesh_20170314t115609_x-0_y-0_vv.png differ diff --git a/tests/data/etci2021/train/bangladesh_20170314t115609/tiles/water_body_label/bangladesh_20170314t115609_x-0_y-0.png b/tests/data/etci2021/train/bangladesh_20170314t115609/tiles/water_body_label/bangladesh_20170314t115609_x-0_y-0.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/etci2021/train/bangladesh_20170314t115609/tiles/water_body_label/bangladesh_20170314t115609_x-0_y-0.png and b/tests/data/etci2021/train/bangladesh_20170314t115609/tiles/water_body_label/bangladesh_20170314t115609_x-0_y-0.png differ diff --git a/tests/data/etci2021/train/nebraska_20170108t002112/tiles/flood_label/nebraska_20170108t002112_x-0_y-0.png b/tests/data/etci2021/train/nebraska_20170108t002112/tiles/flood_label/nebraska_20170108t002112_x-0_y-0.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/etci2021/train/nebraska_20170108t002112/tiles/flood_label/nebraska_20170108t002112_x-0_y-0.png and b/tests/data/etci2021/train/nebraska_20170108t002112/tiles/flood_label/nebraska_20170108t002112_x-0_y-0.png differ diff --git a/tests/data/etci2021/train/nebraska_20170108t002112/tiles/vh/nebraska_20170108t002112_x-0_y-0_vh.png b/tests/data/etci2021/train/nebraska_20170108t002112/tiles/vh/nebraska_20170108t002112_x-0_y-0_vh.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/etci2021/train/nebraska_20170108t002112/tiles/vh/nebraska_20170108t002112_x-0_y-0_vh.png and b/tests/data/etci2021/train/nebraska_20170108t002112/tiles/vh/nebraska_20170108t002112_x-0_y-0_vh.png differ diff --git a/tests/data/etci2021/train/nebraska_20170108t002112/tiles/vv/nebraska_20170108t002112_x-0_y-0_vv.png b/tests/data/etci2021/train/nebraska_20170108t002112/tiles/vv/nebraska_20170108t002112_x-0_y-0_vv.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/etci2021/train/nebraska_20170108t002112/tiles/vv/nebraska_20170108t002112_x-0_y-0_vv.png and b/tests/data/etci2021/train/nebraska_20170108t002112/tiles/vv/nebraska_20170108t002112_x-0_y-0_vv.png differ diff --git a/tests/data/etci2021/train/nebraska_20170108t002112/tiles/water_body_label/nebraska_20170108t002112_x-0_y-0.png b/tests/data/etci2021/train/nebraska_20170108t002112/tiles/water_body_label/nebraska_20170108t002112_x-0_y-0.png index 19b3f6a1e4b..320c3449e5f 100644 Binary files a/tests/data/etci2021/train/nebraska_20170108t002112/tiles/water_body_label/nebraska_20170108t002112_x-0_y-0.png and b/tests/data/etci2021/train/nebraska_20170108t002112/tiles/water_body_label/nebraska_20170108t002112_x-0_y-0.png differ diff --git a/tests/data/etci2021/val_with_ref_labels.zip b/tests/data/etci2021/val_with_ref_labels.zip index acd6d76c332..6077f058efd 100644 Binary files a/tests/data/etci2021/val_with_ref_labels.zip and b/tests/data/etci2021/val_with_ref_labels.zip differ diff --git a/tests/data/levircd/LEVIR-CD+.zip b/tests/data/levircd/LEVIR-CD+.zip index b51dc099207..9a5fa4e1a7c 100644 Binary files a/tests/data/levircd/LEVIR-CD+.zip and b/tests/data/levircd/LEVIR-CD+.zip differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images.zip b/tests/data/oscd/Onera Satellite Change Detection dataset - Images.zip index 9a6cb41a1b9..e686d175811 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images.zip and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images.zip differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/dates.txt b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/dates.txt index fa2ff53aa38..7ee094e0852 100644 --- a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/dates.txt +++ b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/dates.txt @@ -1,2 +1,2 @@ -date_1: 20151211 -date_2: 20180330 +date_1: 20161130 +date_2: 20170829 diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B01.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B01.tif index a3ee8290f6a..31595ee57ee 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B01.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B01.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B02.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B02.tif index a3ee8290f6a..5e14a9b525a 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B02.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B02.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B03.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B03.tif index a3ee8290f6a..6ac049ea88b 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B03.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B03.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B04.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B04.tif index a3ee8290f6a..718b05845f1 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B04.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B04.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B05.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B05.tif index a3ee8290f6a..c6356789554 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B05.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B05.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B06.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B06.tif index a3ee8290f6a..48b17dda24f 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B06.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B06.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B07.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B07.tif index a3ee8290f6a..acf9bfb4cee 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B07.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B07.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B08.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B08.tif index a3ee8290f6a..b9eefaaa743 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B08.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B08.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B09.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B09.tif index a3ee8290f6a..5120b3b91d2 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B09.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B09.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B10.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B10.tif index a3ee8290f6a..b5b6e436327 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B10.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B10.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B11.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B11.tif index a3ee8290f6a..047a39938ff 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B11.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B11.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B12.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B12.tif index a3ee8290f6a..e47a97810de 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B12.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B12.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B8A.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B8A.tif index a3ee8290f6a..0be3b5f5bcf 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B8A.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_1_rect/B8A.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B01.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B01.tif index a3ee8290f6a..a7c74bb99e9 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B01.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B01.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B02.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B02.tif index a3ee8290f6a..9809f480283 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B02.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B02.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B03.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B03.tif index a3ee8290f6a..5d891f19bf1 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B03.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B03.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B04.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B04.tif index a3ee8290f6a..3989aa999e7 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B04.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B04.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B05.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B05.tif index a3ee8290f6a..323f1629690 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B05.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B05.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B06.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B06.tif index a3ee8290f6a..647bac56eed 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B06.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B06.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B07.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B07.tif index a3ee8290f6a..895e6c6b789 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B07.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B07.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B08.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B08.tif index a3ee8290f6a..8067466b418 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B08.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B08.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B09.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B09.tif index a3ee8290f6a..e95a142bf60 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B09.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B09.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B10.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B10.tif index a3ee8290f6a..39f65baa3f5 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B10.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B10.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B11.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B11.tif index a3ee8290f6a..ec0879ca82c 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B11.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B11.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B12.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B12.tif index a3ee8290f6a..92b30223833 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B12.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B12.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B8A.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B8A.tif index a3ee8290f6a..1b130a20bf6 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B8A.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/test/imgs_2_rect/B8A.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B01.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B01.tif index a3ee8290f6a..0fb1192178d 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B01.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B01.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B02.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B02.tif index a3ee8290f6a..848c4441980 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B02.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B02.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B03.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B03.tif index a3ee8290f6a..df966b575f9 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B03.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B03.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B04.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B04.tif index a3ee8290f6a..93d01cb2a79 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B04.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B04.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B05.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B05.tif index a3ee8290f6a..50a214cdf56 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B05.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B05.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B06.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B06.tif index a3ee8290f6a..56751414a67 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B06.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B06.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B07.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B07.tif index a3ee8290f6a..ea57e5466b5 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B07.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B07.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B08.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B08.tif index a3ee8290f6a..5dbc33d5122 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B08.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B08.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B09.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B09.tif index a3ee8290f6a..fb3c4ada78c 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B09.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B09.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B10.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B10.tif index a3ee8290f6a..8b7528cefb6 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B10.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B10.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B11.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B11.tif index a3ee8290f6a..c742ef9091c 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B11.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B11.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B12.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B12.tif index a3ee8290f6a..e285b7cb8e2 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B12.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B12.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B8A.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B8A.tif index a3ee8290f6a..8dad3bd0b35 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B8A.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_1_rect/B8A.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B01.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B01.tif index a3ee8290f6a..b61a82720ed 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B01.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B01.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B02.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B02.tif index a3ee8290f6a..49af5358221 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B02.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B02.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B03.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B03.tif index a3ee8290f6a..5d0a0a07cfb 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B03.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B03.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B04.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B04.tif index a3ee8290f6a..dea8b061347 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B04.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B04.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B05.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B05.tif index a3ee8290f6a..bc688d18120 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B05.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B05.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B06.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B06.tif index a3ee8290f6a..c05df52d962 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B06.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B06.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B07.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B07.tif index a3ee8290f6a..2b432d47c81 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B07.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B07.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B08.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B08.tif index a3ee8290f6a..96ce65b4923 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B08.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B08.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B09.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B09.tif index a3ee8290f6a..d4e1b17c5af 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B09.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B09.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B10.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B10.tif index a3ee8290f6a..8747b6b4fe8 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B10.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B10.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B11.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B11.tif index a3ee8290f6a..2b024f607d0 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B11.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B11.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B12.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B12.tif index a3ee8290f6a..0c2643b23f7 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B12.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B12.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B8A.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B8A.tif index a3ee8290f6a..26048023330 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B8A.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train1/imgs_2_rect/B8A.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B01.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B01.tif index a3ee8290f6a..b64da370d6b 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B01.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B01.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B02.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B02.tif index a3ee8290f6a..3ec91badd61 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B02.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B02.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B03.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B03.tif index a3ee8290f6a..28a3a7192e7 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B03.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B03.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B04.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B04.tif index a3ee8290f6a..c6231ff3c15 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B04.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B04.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B05.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B05.tif index a3ee8290f6a..29dfd694763 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B05.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B05.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B06.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B06.tif index a3ee8290f6a..ff8430de80b 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B06.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B06.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B07.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B07.tif index a3ee8290f6a..f093b329dc6 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B07.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B07.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B08.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B08.tif index a3ee8290f6a..54a404a54ff 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B08.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B08.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B09.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B09.tif index a3ee8290f6a..d3a60ecdb96 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B09.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B09.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B10.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B10.tif index a3ee8290f6a..cd13329201b 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B10.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B10.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B11.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B11.tif index a3ee8290f6a..a1ca762b10e 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B11.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B11.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B12.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B12.tif index a3ee8290f6a..135c138b1fb 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B12.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B12.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B8A.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B8A.tif index a3ee8290f6a..33066a9c723 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B8A.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_1_rect/B8A.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B01.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B01.tif index a3ee8290f6a..1b0efa4bde1 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B01.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B01.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B02.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B02.tif index a3ee8290f6a..f7f9baf720e 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B02.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B02.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B03.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B03.tif index a3ee8290f6a..e04bc7a7eca 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B03.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B03.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B04.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B04.tif index a3ee8290f6a..e1ef595b16a 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B04.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B04.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B05.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B05.tif index a3ee8290f6a..c0d140aaa7d 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B05.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B05.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B06.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B06.tif index a3ee8290f6a..2c0614824bf 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B06.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B06.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B07.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B07.tif index a3ee8290f6a..2e100588d73 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B07.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B07.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B08.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B08.tif index a3ee8290f6a..03e0da41934 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B08.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B08.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B09.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B09.tif index a3ee8290f6a..a4a1c1fb076 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B09.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B09.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B10.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B10.tif index a3ee8290f6a..99e4f640035 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B10.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B10.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B11.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B11.tif index a3ee8290f6a..e4704ad5969 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B11.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B11.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B12.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B12.tif index a3ee8290f6a..6db801c9dde 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B12.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B12.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B8A.tif b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B8A.tif index a3ee8290f6a..00c9abada70 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B8A.tif and b/tests/data/oscd/Onera Satellite Change Detection dataset - Images/train2/imgs_2_rect/B8A.tif differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Test Labels.zip b/tests/data/oscd/Onera Satellite Change Detection dataset - Test Labels.zip index 04ec4f8e338..ce1415975d0 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Test Labels.zip and b/tests/data/oscd/Onera Satellite Change Detection dataset - Test Labels.zip differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Test Labels/test/cm/cm.png b/tests/data/oscd/Onera Satellite Change Detection dataset - Test Labels/test/cm/cm.png index 337cd665214..3abc8741c86 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Test Labels/test/cm/cm.png and b/tests/data/oscd/Onera Satellite Change Detection dataset - Test Labels/test/cm/cm.png differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Train Labels.zip b/tests/data/oscd/Onera Satellite Change Detection dataset - Train Labels.zip index 4332a78b828..6a91068749e 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Train Labels.zip and b/tests/data/oscd/Onera Satellite Change Detection dataset - Train Labels.zip differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Train Labels/train1/cm/cm.png b/tests/data/oscd/Onera Satellite Change Detection dataset - Train Labels/train1/cm/cm.png index 337cd665214..131f6e25241 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Train Labels/train1/cm/cm.png and b/tests/data/oscd/Onera Satellite Change Detection dataset - Train Labels/train1/cm/cm.png differ diff --git a/tests/data/oscd/Onera Satellite Change Detection dataset - Train Labels/train2/cm/cm.png b/tests/data/oscd/Onera Satellite Change Detection dataset - Train Labels/train2/cm/cm.png index 337cd665214..a6869542b63 100644 Binary files a/tests/data/oscd/Onera Satellite Change Detection dataset - Train Labels/train2/cm/cm.png and b/tests/data/oscd/Onera Satellite Change Detection dataset - Train Labels/train2/cm/cm.png differ diff --git a/tests/data/oscd/data.py b/tests/data/oscd/data.py new file mode 100755 index 00000000000..9f576720eee --- /dev/null +++ b/tests/data/oscd/data.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import hashlib +import os +import shutil + +import numpy as np +from PIL import Image + + +SIZE = 64 # image width/height + +np.random.seed(0) + +directories = [ + "Onera Satellite Change Detection dataset - Images", + "Onera Satellite Change Detection dataset - Train Labels", + "Onera Satellite Change Detection dataset - Test Labels", +] +bands = [ + "B01", + "B02", + "B03", + "B04", + "B05", + "B06", + "B07", + "B08", + "B09", + "B10", + "B11", + "B12", + "B8A", +] + +# Remove old data +for directory in directories: + filename = f"{directory}.zip" + + if os.path.exists(filename): + os.remove(filename) + if os.path.exists(directory): + shutil.rmtree(directory) + +# Create images +for subdir in ["train1", "train2", "test"]: + for rect in ["imgs_1_rect", "imgs_2_rect"]: + directory = os.path.join(directories[0], subdir, rect) + os.makedirs(directory) + + for band in bands: + filename = os.path.join(directory, f"{band}.tif") + arr = np.random.randint( + np.iinfo(np.uint16).max, size=(SIZE, SIZE), dtype=np.uint16 + ) + img = Image.fromarray(arr) + img.save(filename) + + filename = os.path.join(directories[0], subdir, "dates.txt") + with open(filename, "w") as f: + for key, value in [("date_1", "20161130"), ("date_2", "20170829")]: + f.write(f"{key}: {value}\n") + +# Create labels +for i, subdir in [(1, "train1"), (1, "train2"), (2, "test")]: + directory = os.path.join(directories[i], subdir, "cm") + os.makedirs(directory) + filename = os.path.join(directory, "cm.png") + arr = np.random.randint(np.iinfo(np.uint8).max, size=(SIZE, SIZE), dtype=np.uint8) + img = Image.fromarray(arr) + img.save(filename) + +for directory in directories: + # Compress data + shutil.make_archive(directory, "zip", ".", directory) + + # Compute checksums + filename = f"{directory}.zip" + with open(filename, "rb") as f: + md5 = hashlib.md5(f.read()).hexdigest() + print(repr(filename) + ": " + repr(md5) + ",") diff --git a/tests/data/so2sat/data.py b/tests/data/so2sat/data.py new file mode 100755 index 00000000000..f0a2c01f2aa --- /dev/null +++ b/tests/data/so2sat/data.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import hashlib +import os + +import h5py +import numpy as np + + +SIZE = 64 # image width/height +NUM_CLASSES = 17 +NUM_SAMPLES = 1 + +np.random.seed(0) + +for split in ["training", "validation", "testing"]: + filename = f"{split}.h5" + + # Remove old data + if os.path.exists(filename): + os.remove(filename) + + # Random one hot encoding + label = np.eye(NUM_CLASSES, dtype=np.uint8)[ + np.random.choice(NUM_CLASSES, NUM_SAMPLES) + ] + + # Random images + sen1 = np.random.randint(256, size=(NUM_SAMPLES, SIZE, SIZE, 8), dtype=np.uint8) + sen2 = np.random.randint(256, size=(NUM_SAMPLES, SIZE, SIZE, 10), dtype=np.uint8) + + # Create datasets + with h5py.File(filename, "w") as f: + f.create_dataset("label", data=label) + f.create_dataset("sen1", data=sen1) + f.create_dataset("sen2", data=sen2) + + # Compute checksums + with open(filename, "rb") as f: + md5 = hashlib.md5(f.read()).hexdigest() + print(repr(split.replace("ing", "")) + ":", repr(md5) + ",") diff --git a/tests/data/so2sat/testing.h5 b/tests/data/so2sat/testing.h5 index 51169280abb..464108e4b9d 100644 Binary files a/tests/data/so2sat/testing.h5 and b/tests/data/so2sat/testing.h5 differ diff --git a/tests/data/so2sat/training.h5 b/tests/data/so2sat/training.h5 index 7de7b665c33..b433518217f 100644 Binary files a/tests/data/so2sat/training.h5 and b/tests/data/so2sat/training.h5 differ diff --git a/tests/data/so2sat/validation.h5 b/tests/data/so2sat/validation.h5 index f959a2aced6..a1ec7f5d1da 100644 Binary files a/tests/data/so2sat/validation.h5 and b/tests/data/so2sat/validation.h5 differ diff --git a/tests/data/xview3/sample_data.tar.gz b/tests/data/xview3/sample_data.tar.gz new file mode 100644 index 00000000000..47639f1990d Binary files /dev/null and b/tests/data/xview3/sample_data.tar.gz differ diff --git a/tests/datamodules/__init__.py b/tests/datamodules/__init__.py new file mode 100644 index 00000000000..5b7f7a925cc --- /dev/null +++ b/tests/datamodules/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. diff --git a/tests/datamodules/test_chesapeake.py b/tests/datamodules/test_chesapeake.py new file mode 100644 index 00000000000..4dde9bbf428 --- /dev/null +++ b/tests/datamodules/test_chesapeake.py @@ -0,0 +1,55 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +from typing import Any, Dict, cast + +import pytest +import torch +from omegaconf import OmegaConf + +from torchgeo.datamodules import ChesapeakeCVPRDataModule + + +class TestChesapeakeCVPRDataModule: + @pytest.fixture(scope="class") + def datamodule(self) -> ChesapeakeCVPRDataModule: + conf = OmegaConf.load( + os.path.join("conf", "task_defaults", "chesapeake_cvpr_5.yaml") + ) + kwargs = OmegaConf.to_object(conf.experiment.datamodule) + kwargs = cast(Dict[str, Any], kwargs) + + datamodule = ChesapeakeCVPRDataModule(**kwargs) + datamodule.prepare_data() + datamodule.setup() + return datamodule + + def test_nodata_check(self, datamodule: ChesapeakeCVPRDataModule) -> None: + nodata_check = datamodule.nodata_check(4) + sample = { + "image": torch.ones(1, 2, 2), # type: ignore[attr-defined] + "mask": torch.ones(2, 2), # type: ignore[attr-defined] + } + out = nodata_check(sample) + assert torch.equal( # type: ignore[attr-defined] + out["image"], torch.zeros(1, 4, 4) # type: ignore[attr-defined] + ) + assert torch.equal( # type: ignore[attr-defined] + out["mask"], torch.zeros(4, 4) # type: ignore[attr-defined] + ) + + def test_invalid_param_config(self) -> None: + with pytest.raises(ValueError, match="The pre-generated prior labels"): + ChesapeakeCVPRDataModule( + os.path.join("tests", "data", "chesapeake", "cvpr"), + ["de-test"], + ["de-test"], + ["de-test"], + patch_size=32, + patches_per_tile=2, + batch_size=2, + num_workers=0, + class_set=7, + use_prior_labels=True, + ) diff --git a/tests/datamodules/test_fair1m.py b/tests/datamodules/test_fair1m.py new file mode 100644 index 00000000000..1f19922f1eb --- /dev/null +++ b/tests/datamodules/test_fair1m.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest + +from torchgeo.datamodules import FAIR1MDataModule + + +class TestFAIR1MDataModule: + @pytest.fixture(scope="class", params=[True, False]) + def datamodule(self) -> FAIR1MDataModule: + root = os.path.join("tests", "data", "fair1m") + batch_size = 2 + num_workers = 0 + dm = FAIR1MDataModule( + root, batch_size, num_workers, val_split_pct=0.33, test_split_pct=0.33 + ) + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: FAIR1MDataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: FAIR1MDataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: FAIR1MDataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/datamodules/test_loveda.py b/tests/datamodules/test_loveda.py new file mode 100644 index 00000000000..c19e8cb0ab9 --- /dev/null +++ b/tests/datamodules/test_loveda.py @@ -0,0 +1,34 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest + +from torchgeo.datamodules import LoveDADataModule + + +class TestLoveDADataModule: + @pytest.fixture(scope="class") + def datamodule(self) -> LoveDADataModule: + root = os.path.join("tests", "data", "loveda") + batch_size = 2 + num_workers = 0 + scene = ["rural", "urban"] + + dm = LoveDADataModule( + root_dir=root, scene=scene, batch_size=batch_size, num_workers=num_workers + ) + + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: LoveDADataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: LoveDADataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: LoveDADataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/datamodules/test_nasa_marine_debris.py b/tests/datamodules/test_nasa_marine_debris.py new file mode 100644 index 00000000000..eff571f953c --- /dev/null +++ b/tests/datamodules/test_nasa_marine_debris.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest + +from torchgeo.datamodules import NASAMarineDebrisDataModule + + +class TestNASAMarineDebrisDataModule: + @pytest.fixture(scope="class") + def datamodule(self) -> NASAMarineDebrisDataModule: + root = os.path.join("tests", "data", "nasa_marine_debris") + batch_size = 2 + num_workers = 0 + val_split_pct = 0.3 + test_split_pct = 0.3 + dm = NASAMarineDebrisDataModule( + root, batch_size, num_workers, val_split_pct, test_split_pct + ) + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: NASAMarineDebrisDataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: NASAMarineDebrisDataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: NASAMarineDebrisDataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/datamodules/test_oscd.py b/tests/datamodules/test_oscd.py new file mode 100644 index 00000000000..7d090f99c97 --- /dev/null +++ b/tests/datamodules/test_oscd.py @@ -0,0 +1,62 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest +from _pytest.fixtures import SubRequest + +from torchgeo.datamodules import OSCDDataModule + + +class TestOSCDDataModule: + @pytest.fixture(scope="class", params=zip(["all", "rgb"], [0.0, 0.5])) + def datamodule(self, request: SubRequest) -> OSCDDataModule: + bands, val_split_pct = request.param + patch_size = (2, 2) + num_patches_per_tile = 2 + root = os.path.join("tests", "data", "oscd") + batch_size = 1 + num_workers = 0 + dm = OSCDDataModule( + root, + bands, + batch_size, + num_workers, + val_split_pct, + patch_size, + num_patches_per_tile, + ) + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: OSCDDataModule) -> None: + sample = next(iter(datamodule.train_dataloader())) + assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2) + assert sample["image"].shape[0] == sample["mask"].shape[0] == 2 + if datamodule.bands == "all": + assert sample["image"].shape[1] == 26 + else: + assert sample["image"].shape[1] == 6 + + def test_val_dataloader(self, datamodule: OSCDDataModule) -> None: + sample = next(iter(datamodule.val_dataloader())) + if datamodule.val_split_pct > 0.0: + assert ( + sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (1280, 1280) + ) + assert sample["image"].shape[0] == sample["mask"].shape[0] == 1 + if datamodule.bands == "all": + assert sample["image"].shape[1] == 26 + else: + assert sample["image"].shape[1] == 6 + + def test_test_dataloader(self, datamodule: OSCDDataModule) -> None: + sample = next(iter(datamodule.test_dataloader())) + assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (1280, 1280) + assert sample["image"].shape[0] == sample["mask"].shape[0] == 1 + if datamodule.bands == "all": + assert sample["image"].shape[1] == 26 + else: + assert sample["image"].shape[1] == 6 diff --git a/tests/datamodules/test_potsdam.py b/tests/datamodules/test_potsdam.py new file mode 100644 index 00000000000..f67be0fea7c --- /dev/null +++ b/tests/datamodules/test_potsdam.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest +from _pytest.fixtures import SubRequest + +from torchgeo.datamodules import Potsdam2DDataModule + + +class TestPotsdam2DDataModule: + @pytest.fixture(scope="class", params=[0.0, 0.5]) + def datamodule(self, request: SubRequest) -> Potsdam2DDataModule: + root = os.path.join("tests", "data", "potsdam") + batch_size = 1 + num_workers = 0 + val_split_size = request.param + dm = Potsdam2DDataModule( + root, batch_size, num_workers, val_split_pct=val_split_size + ) + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: Potsdam2DDataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: Potsdam2DDataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: Potsdam2DDataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/datamodules/test_utils.py b/tests/datamodules/test_utils.py new file mode 100644 index 00000000000..e5bc527f6c3 --- /dev/null +++ b/tests/datamodules/test_utils.py @@ -0,0 +1,25 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import torch +from torch.utils.data import TensorDataset + +from torchgeo.datamodules.utils import dataset_split + + +def test_dataset_split() -> None: + num_samples = 24 + x = torch.ones(num_samples, 5) # type: ignore[attr-defined] + y = torch.randint(low=0, high=2, size=(num_samples,)) # type: ignore[attr-defined] + ds = TensorDataset(x, y) + + # Test only train/val set split + train_ds, val_ds = dataset_split(ds, val_pct=1 / 2) + assert len(train_ds) == num_samples // 2 + assert len(val_ds) == num_samples // 2 + + # Test train/val/test set split + train_ds, val_ds, test_ds = dataset_split(ds, val_pct=1 / 3, test_pct=1 / 3) + assert len(train_ds) == num_samples // 3 + assert len(val_ds) == num_samples // 3 + assert len(test_ds) == num_samples // 3 diff --git a/tests/datamodules/test_vaihingen.py b/tests/datamodules/test_vaihingen.py new file mode 100644 index 00000000000..453a987ecef --- /dev/null +++ b/tests/datamodules/test_vaihingen.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest +from _pytest.fixtures import SubRequest + +from torchgeo.datamodules import Vaihingen2DDataModule + + +class TestVaihingen2DDataModule: + @pytest.fixture(scope="class", params=[0.0, 0.5]) + def datamodule(self, request: SubRequest) -> Vaihingen2DDataModule: + root = os.path.join("tests", "data", "vaihingen") + batch_size = 1 + num_workers = 0 + val_split_size = request.param + dm = Vaihingen2DDataModule( + root, batch_size, num_workers, val_split_pct=val_split_size + ) + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: Vaihingen2DDataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: Vaihingen2DDataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: Vaihingen2DDataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/datamodules/test_xview2.py b/tests/datamodules/test_xview2.py new file mode 100644 index 00000000000..5e1637533d6 --- /dev/null +++ b/tests/datamodules/test_xview2.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest +from _pytest.fixtures import SubRequest + +from torchgeo.datamodules import XView2DataModule + + +class TestXView2DataModule: + @pytest.fixture(scope="class", params=[0.0, 0.5]) + def datamodule(self, request: SubRequest) -> XView2DataModule: + root = os.path.join("tests", "data", "xview2") + batch_size = 1 + num_workers = 0 + val_split_size = request.param + dm = XView2DataModule( + root, batch_size, num_workers, val_split_pct=val_split_size + ) + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: XView2DataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: XView2DataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: XView2DataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_advance.py b/tests/datasets/test_advance.py index 30387142b53..ff2283c83d3 100644 --- a/tests/datasets/test_advance.py +++ b/tests/datasets/test_advance.py @@ -7,6 +7,7 @@ from pathlib import Path from typing import Any, Generator +import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn @@ -86,3 +87,13 @@ def test_mock_missing_module( match="scipy is not installed and is required to use this dataset", ): dataset[0] + + def test_plot(self, dataset: ADVANCE) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + x["prediction"] = x["label"].clone() + dataset.plot(x) + plt.close() diff --git a/tests/datasets/test_benin_cashews.py b/tests/datasets/test_benin_cashews.py index f4f6d285f2b..ea19ad3a378 100644 --- a/tests/datasets/test_benin_cashews.py +++ b/tests/datasets/test_benin_cashews.py @@ -7,6 +7,7 @@ from pathlib import Path from typing import Generator +import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn @@ -49,9 +50,12 @@ def dataset( ) root = str(tmp_path) transforms = nn.Identity() # type: ignore[attr-defined] + bands = BeninSmallHolderCashews.ALL_BANDS + return BeninSmallHolderCashews( root, transforms=transforms, + bands=bands, download=True, api_key="", checksum=True, @@ -87,3 +91,19 @@ def test_invalid_bands(self) -> None: with pytest.raises(ValueError, match="is an invalid band name."): BeninSmallHolderCashews(bands=("foo", "bar")) + + def test_plot(self, dataset: BeninSmallHolderCashews) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + x["prediction"] = x["mask"].clone() + dataset.plot(x) + plt.close() + + def test_failed_plot(self, dataset: BeninSmallHolderCashews) -> None: + single_band_dataset = BeninSmallHolderCashews(root=dataset.root, bands=("B01",)) + with pytest.raises(ValueError, match="Dataset doesn't contain"): + x = single_band_dataset[0].copy() + single_band_dataset.plot(x, suptitle="Test") diff --git a/tests/datasets/test_bigearthnet.py b/tests/datasets/test_bigearthnet.py index 2561eb9f8e5..84307a484a6 100644 --- a/tests/datasets/test_bigearthnet.py +++ b/tests/datasets/test_bigearthnet.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Generator +import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn @@ -13,7 +14,7 @@ from _pytest.monkeypatch import MonkeyPatch import torchgeo.datasets.utils -from torchgeo.datasets import BigEarthNet, BigEarthNetDataModule +from torchgeo.datasets import BigEarthNet def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: @@ -149,25 +150,12 @@ def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match=err): BigEarthNet(str(tmp_path)) - -class TestBigEarthNetDataModule: - @pytest.fixture(scope="class", params=["s1", "s2", "all"]) - def datamodule(self, request: SubRequest) -> BigEarthNetDataModule: - bands = request.param - root = os.path.join("tests", "data", "bigearthnet") - num_classes = 19 - batch_size = 1 - num_workers = 0 - dm = BigEarthNetDataModule(root, bands, num_classes, batch_size, num_workers) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: BigEarthNetDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: BigEarthNetDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: BigEarthNetDataModule) -> None: - next(iter(datamodule.test_dataloader())) + def test_plot(self, dataset: BigEarthNet) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + x["prediction"] = x["label"].clone() + dataset.plot(x) + plt.close() diff --git a/tests/datasets/test_chesapeake.py b/tests/datasets/test_chesapeake.py index 573719ce9f3..1393fdb2bd4 100644 --- a/tests/datasets/test_chesapeake.py +++ b/tests/datasets/test_chesapeake.py @@ -19,7 +19,6 @@ BoundingBox, Chesapeake13, ChesapeakeCVPR, - ChesapeakeCVPRDataModule, IntersectionDataset, UnionDataset, ) @@ -102,6 +101,7 @@ class TestChesapeakeCVPR: ("naip-new", "naip-old", "nlcd"), ("landsat-leaf-on", "landsat-leaf-off", "lc"), ("naip-new", "landsat-leaf-on", "lc", "nlcd", "buildings"), + ("naip-new", "prior_from_cooccurrences_101_31_no_osm_no_buildings"), ] ) def dataset( @@ -113,12 +113,34 @@ def dataset( monkeypatch.setattr( # type: ignore[attr-defined] torchgeo.datasets.chesapeake, "download_url", download_url ) - md5 = "882d18b1f15ea4498bf54e674aecd5d4" - monkeypatch.setattr(ChesapeakeCVPR, "md5", md5) # type: ignore[attr-defined] - url = os.path.join( - "tests", "data", "chesapeake", "cvpr", "cvpr_chesapeake_landcover.zip" + monkeypatch.setattr( # type: ignore[attr-defined] + ChesapeakeCVPR, + "md5s", + { + "base": "882d18b1f15ea4498bf54e674aecd5d4", + "prior_extension": "677446c486f3145787938b14ee3da13f", + }, + ) + monkeypatch.setattr( # type: ignore[attr-defined] + ChesapeakeCVPR, + "urls", + { + "base": os.path.join( + "tests", + "data", + "chesapeake", + "cvpr", + "cvpr_chesapeake_landcover.zip", + ), + "prior_extension": os.path.join( + "tests", + "data", + "chesapeake", + "cvpr", + "cvpr_chesapeake_landcover_prior_extension.zip", + ), + }, ) - monkeypatch.setattr(ChesapeakeCVPR, "url", url) # type: ignore[attr-defined] monkeypatch.setattr( # type: ignore[attr-defined] ChesapeakeCVPR, "files", @@ -153,11 +175,23 @@ def test_already_extracted(self, dataset: ChesapeakeCVPR) -> None: ChesapeakeCVPR(root=dataset.root, download=True) def test_already_downloaded(self, tmp_path: Path) -> None: - url = os.path.join( - "tests", "data", "chesapeake", "cvpr", "cvpr_chesapeake_landcover.zip" - ) root = str(tmp_path) - shutil.copy(url, root) + shutil.copy( + os.path.join( + "tests", "data", "chesapeake", "cvpr", "cvpr_chesapeake_landcover.zip" + ), + root, + ) + shutil.copy( + os.path.join( + "tests", + "data", + "chesapeake", + "cvpr", + "cvpr_chesapeake_landcover_prior_extension.zip", + ), + root, + ) ChesapeakeCVPR(root) def test_not_downloaded(self, tmp_path: Path) -> None: @@ -179,45 +213,3 @@ def test_multiple_hits_query(self, dataset: ChesapeakeCVPR) -> None: IndexError, match="query: .* spans multiple tiles which is not valid" ): ds[dataset.bounds] - - -class TestChesapeakeCVPRDataModule: - @pytest.fixture(scope="class", params=[5, 7]) - def datamodule(self, request: SubRequest) -> ChesapeakeCVPRDataModule: - dm = ChesapeakeCVPRDataModule( - os.path.join("tests", "data", "chesapeake", "cvpr"), - ["de-test"], - ["de-test"], - ["de-test"], - patch_size=32, - patches_per_tile=2, - batch_size=2, - num_workers=0, - class_set=request.param, - ) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: ChesapeakeCVPRDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: ChesapeakeCVPRDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: ChesapeakeCVPRDataModule) -> None: - next(iter(datamodule.test_dataloader())) - - def test_nodata_check(self, datamodule: ChesapeakeCVPRDataModule) -> None: - nodata_check = datamodule.nodata_check(4) - sample = { - "image": torch.ones(1, 2, 2), # type: ignore[attr-defined] - "mask": torch.ones(2, 2), # type: ignore[attr-defined] - } - out = nodata_check(sample) - assert torch.equal( # type: ignore[attr-defined] - out["image"], torch.zeros(1, 4, 4) # type: ignore[attr-defined] - ) - assert torch.equal( # type: ignore[attr-defined] - out["mask"], torch.zeros(4, 4) # type: ignore[attr-defined] - ) diff --git a/tests/datasets/test_cowc.py b/tests/datasets/test_cowc.py index 6ec7b533007..87bec26af27 100644 --- a/tests/datasets/test_cowc.py +++ b/tests/datasets/test_cowc.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Generator +import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn @@ -14,7 +15,7 @@ from torch.utils.data import ConcatDataset import torchgeo.datasets.utils -from torchgeo.datasets import COWCCounting, COWCCountingDataModule, COWCDetection +from torchgeo.datasets import COWCCounting, COWCDetection from torchgeo.datasets.cowc import COWC @@ -44,14 +45,14 @@ def dataset( COWCCounting, "base_url", base_url ) md5s = [ - "a729b6e29278a9a000aa349dad3c78cb", - "a8ff4c4de4b8c66bd9c5ec17f532b3a2", - "bc6b9493b8e39b87d189cadcc4823e05", - "f111948e2ac262c024c8fe32ba5b1434", - "8c333fcfa4168afa5376310958d15166", - "479670049aa9a48b4895cff6db3aa615", - "56043d4716ad0a1eedd392b0a543973b", - "b77193aef7c473379cd8d4e40d413137", + "7d0c6d1fb548d3ea3a182a56ce231f97", + "2e9a806b19b21f9d796c7393ad8f51ee", + "39453c0627effd908e773c5c1f8aecc9", + "67190b3e0ca8aa1fc93250aa5383a8f3", + "575aead6a0c92aba37d613895194da7c", + "e7c2279040d3ce31b9c925c45d0c61e2", + "f159e23d52bd0b5656fe296f427b98e1", + "0a4daed8c5f6c4e20faa6e38636e4346", ] monkeypatch.setattr(COWCCounting, "md5s", md5s) # type: ignore[attr-defined] root = str(tmp_path) @@ -88,6 +89,16 @@ def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): COWCCounting(str(tmp_path)) + def test_plot(self, dataset: COWCCounting) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + x["prediction"] = x["label"].clone() + dataset.plot(x) + plt.close() + class TestCOWCDetection: @pytest.fixture(params=["train", "test"]) @@ -105,14 +116,14 @@ def dataset( COWCDetection, "base_url", base_url ) md5s = [ - "cc913824d9aa6c7af6f957dcc2cb9690", - "f8e07e70958d8d57ab464f62e9abab80", - "6a481cd785b0f16e9e1ab016a0695e57", - "e9578491977d291def2611b84c84fdfd", - "0bb1c285b170c23a8590cf2926fd224e", - "60fa485b16c0e5b28db756fd1d8a0438", - "97c886fb7558f4e8779628917ca64596", - "ab21a117b754e04e65c63f94aa648e33", + "6bbbdb36ee4922e879f66ed9234cb8ab", + "09e4af08c6e6553afe5098b328ce9749", + "12a2708ab7644766e43f5aae34aa7f2a", + "a896433398a0c58263c0d266cfc93bc4", + "911ed42c104db60f7a7d03a5b36bc1ab", + "4cdb4fefab6a2951591e7840c11a229d", + "dd315cfb48dfa7ddb8230c942682bc37", + "dccc2257e9c4a9dde2b4f84769804046", ] monkeypatch.setattr(COWCDetection, "md5s", md5s) # type: ignore[attr-defined] root = str(tmp_path) @@ -149,24 +160,12 @@ def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): COWCDetection(str(tmp_path)) - -class TestCOWCCountingDataModule: - @pytest.fixture(scope="class") - def datamodule(self) -> COWCCountingDataModule: - root = os.path.join("tests", "data", "cowc_counting") - seed = 0 - batch_size = 1 - num_workers = 0 - dm = COWCCountingDataModule(root, seed, batch_size, num_workers) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: COWCCountingDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: COWCCountingDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: COWCCountingDataModule) -> None: - next(iter(datamodule.test_dataloader())) + def test_plot(self, dataset: COWCDetection) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + x["prediction"] = x["label"].clone() + dataset.plot(x) + plt.close() diff --git a/tests/datasets/test_cv4a_kenya_crop_type.py b/tests/datasets/test_cv4a_kenya_crop_type.py index 2c9f2cbbb85..fe8f2cf7bc1 100644 --- a/tests/datasets/test_cv4a_kenya_crop_type.py +++ b/tests/datasets/test_cv4a_kenya_crop_type.py @@ -7,6 +7,7 @@ from pathlib import Path from typing import Generator +import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn @@ -113,3 +114,17 @@ def test_invalid_bands(self) -> None: with pytest.raises(ValueError, match="is an invalid band name."): CV4AKenyaCropType(bands=("foo", "bar")) + + def test_plot(self, dataset: CV4AKenyaCropType) -> None: + dataset.plot(dataset[0], time_step=0, suptitle="Test") + plt.close() + + sample = dataset[0] + sample["prediction"] = sample["mask"].clone() + dataset.plot(sample, time_step=0, suptitle="Pred") + plt.close() + + def test_plot_rgb(self, dataset: CV4AKenyaCropType) -> None: + dataset = CV4AKenyaCropType(root=dataset.root, bands=tuple(["B01"])) + with pytest.raises(ValueError, match="doesn't contain some of the RGB bands"): + dataset.plot(dataset[0], time_step=0, suptitle="Single Band") diff --git a/tests/datasets/test_cyclone.py b/tests/datasets/test_cyclone.py index 6955143a1fb..c9bb803c856 100644 --- a/tests/datasets/test_cyclone.py +++ b/tests/datasets/test_cyclone.py @@ -15,7 +15,7 @@ from _pytest.monkeypatch import MonkeyPatch from torch.utils.data import ConcatDataset -from torchgeo.datasets import CycloneDataModule, TropicalCycloneWindEstimation +from torchgeo.datasets import TropicalCycloneWindEstimation class Dataset: @@ -103,25 +103,3 @@ def test_plot(self, dataset: TropicalCycloneWindEstimation) -> None: ) dataset.plot(sample) plt.close() - - -class TestCycloneDataModule: - @pytest.fixture(scope="class") - def datamodule(self) -> CycloneDataModule: - root = os.path.join("tests", "data", "cyclone") - seed = 0 - batch_size = 1 - num_workers = 0 - dm = CycloneDataModule(root, seed, batch_size, num_workers) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: CycloneDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: CycloneDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: CycloneDataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_etci2021.py b/tests/datasets/test_etci2021.py index 0aaee918be8..232fcd88976 100644 --- a/tests/datasets/test_etci2021.py +++ b/tests/datasets/test_etci2021.py @@ -14,7 +14,7 @@ from _pytest.monkeypatch import MonkeyPatch import torchgeo.datasets.utils -from torchgeo.datasets import ETCI2021, ETCI2021DataModule +from torchgeo.datasets import ETCI2021 def download_url(url: str, root: str, *args: str) -> None: @@ -36,19 +36,19 @@ def dataset( metadata = { "train": { "filename": "train.zip", - "md5": "50c10eb07d6db9aee3ba36401e4a2c45", + "md5": "ebbd2e65cd10621bc2e90a230b474b8b", "directory": "train", "url": os.path.join(data_dir, "train.zip"), }, "val": { "filename": "val_with_ref_labels.zip", - "md5": "3e8b5a3cb95e6029e0e2c2d4b4ec6fba", + "md5": "efdd1fe6c90f5dfd267c88b86b237c2b", "directory": "test", "url": os.path.join(data_dir, "val_with_ref_labels.zip"), }, "test": { "filename": "test_without_ref_labels.zip", - "md5": "c8ee1e5d3e478761cd00ebc6f28b0ae7", + "md5": "bf1180143de5705fe95fa8490835d6d1", "directory": "test_internal", "url": os.path.join(data_dir, "test_without_ref_labels.zip"), }, @@ -95,25 +95,3 @@ def test_plot(self, dataset: ETCI2021) -> None: x["prediction"] = x["mask"][0].clone() dataset.plot(x) plt.close() - - -class TestETCI2021DataModule: - @pytest.fixture(scope="class") - def datamodule(self) -> ETCI2021DataModule: - root = os.path.join("tests", "data", "etci2021") - seed = 0 - batch_size = 2 - num_workers = 0 - dm = ETCI2021DataModule(root, seed, batch_size, num_workers) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: ETCI2021DataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: ETCI2021DataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: ETCI2021DataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_eurosat.py b/tests/datasets/test_eurosat.py index 008195bb72a..a8b47ea2561 100644 --- a/tests/datasets/test_eurosat.py +++ b/tests/datasets/test_eurosat.py @@ -15,7 +15,7 @@ from torch.utils.data import ConcatDataset import torchgeo.datasets.utils -from torchgeo.datasets import EuroSAT, EuroSATDataModule +from torchgeo.datasets import EuroSAT def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: @@ -100,24 +100,3 @@ def test_plot(self, dataset: EuroSAT) -> None: x["prediction"] = x["label"].clone() dataset.plot(x) plt.close() - - -class TestEuroSATDataModule: - @pytest.fixture(scope="class") - def datamodule(self) -> EuroSATDataModule: - root = os.path.join("tests", "data", "eurosat") - batch_size = 1 - num_workers = 0 - dm = EuroSATDataModule(root, batch_size, num_workers) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: EuroSATDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: EuroSATDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: EuroSATDataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_fair1m.py b/tests/datasets/test_fair1m.py index 9b175c649a0..3f188ebb6e1 100644 --- a/tests/datasets/test_fair1m.py +++ b/tests/datasets/test_fair1m.py @@ -12,7 +12,7 @@ import torch.nn as nn from _pytest.monkeypatch import MonkeyPatch -from torchgeo.datasets import FAIR1M, FAIR1MDataModule +from torchgeo.datasets import FAIR1M class TestFAIR1M: @@ -73,25 +73,3 @@ def test_plot(self, dataset: FAIR1M) -> None: x["prediction_boxes"] = x["boxes"].clone() dataset.plot(x) plt.close() - - -class TestFAIR1MDataModule: - @pytest.fixture(scope="class", params=[True, False]) - def datamodule(self) -> FAIR1MDataModule: - root = os.path.join("tests", "data", "fair1m") - batch_size = 2 - num_workers = 0 - dm = FAIR1MDataModule( - root, batch_size, num_workers, val_split_pct=0.33, test_split_pct=0.33 - ) - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: FAIR1MDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: FAIR1MDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: FAIR1MDataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_landcoverai.py b/tests/datasets/test_landcoverai.py index 8a6942e5800..e971f88ae7f 100644 --- a/tests/datasets/test_landcoverai.py +++ b/tests/datasets/test_landcoverai.py @@ -15,7 +15,7 @@ from torch.utils.data import ConcatDataset import torchgeo.datasets.utils -from torchgeo.datasets import LandCoverAI, LandCoverAIDataModule +from torchgeo.datasets import LandCoverAI def download_url(url: str, root: str, *args: str) -> None: @@ -78,24 +78,3 @@ def test_plot(self, dataset: LandCoverAI) -> None: x["prediction"] = x["mask"].clone() dataset.plot(x) plt.close() - - -class TestLandCoverAIDataModule: - @pytest.fixture(scope="class") - def datamodule(self) -> LandCoverAIDataModule: - root = os.path.join("tests", "data", "landcoverai") - batch_size = 2 - num_workers = 0 - dm = LandCoverAIDataModule(root, batch_size, num_workers) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: LandCoverAIDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: LandCoverAIDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: LandCoverAIDataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_levircd.py b/tests/datasets/test_levircd.py index 2aca6c8b0c5..f61bc241be8 100644 --- a/tests/datasets/test_levircd.py +++ b/tests/datasets/test_levircd.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Generator +import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn @@ -31,7 +32,7 @@ def dataset( monkeypatch.setattr( # type: ignore[attr-defined] torchgeo.datasets.utils, "download_url", download_url ) - md5 = "b61c300e9fd7146eb2c8e2512c0e9d39" + md5 = "1adf156f628aa32fb2e8fe6cada16c04" monkeypatch.setattr(LEVIRCDPlus, "md5", md5) # type: ignore[attr-defined] url = os.path.join("tests", "data", "levircd", "LEVIR-CD+.zip") monkeypatch.setattr(LEVIRCDPlus, "url", url) # type: ignore[attr-defined] @@ -60,3 +61,12 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): LEVIRCDPlus(str(tmp_path)) + + def test_plot(self, dataset: LEVIRCDPlus) -> None: + dataset.plot(dataset[0], suptitle="Test") + plt.close() + + sample = dataset[0] + sample["prediction"] = sample["mask"].clone() + dataset.plot(sample, suptitle="Prediction") + plt.close() diff --git a/tests/datasets/test_loveda.py b/tests/datasets/test_loveda.py index 0bfca7bc6c2..e445ae9d3d4 100644 --- a/tests/datasets/test_loveda.py +++ b/tests/datasets/test_loveda.py @@ -14,7 +14,7 @@ from _pytest.monkeypatch import MonkeyPatch import torchgeo.datasets.utils -from torchgeo.datasets import LoveDA, LoveDADataModule +from torchgeo.datasets import LoveDA def download_url(url: str, root: str, *args: str) -> None: @@ -99,29 +99,3 @@ def test_not_downloaded(self, tmp_path: Path) -> None: def test_plot(self, dataset: LoveDA) -> None: dataset.plot(dataset[0], suptitle="Test") plt.close() - - -class TestLoveDADataModule: - @pytest.fixture(scope="class") - def datamodule(self) -> LoveDADataModule: - root = os.path.join("tests", "data", "loveda") - batch_size = 2 - num_workers = 0 - scene = ["rural", "urban"] - - dm = LoveDADataModule( - root_dir=root, scene=scene, batch_size=batch_size, num_workers=num_workers - ) - - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: LoveDADataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: LoveDADataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: LoveDADataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_naip.py b/tests/datasets/test_naip.py index 4ef17e7cfc7..2089d09ac45 100644 --- a/tests/datasets/test_naip.py +++ b/tests/datasets/test_naip.py @@ -12,13 +12,7 @@ from _pytest.monkeypatch import MonkeyPatch from rasterio.crs import CRS -from torchgeo.datasets import ( - NAIP, - BoundingBox, - IntersectionDataset, - NAIPChesapeakeDataModule, - UnionDataset, -) +from torchgeo.datasets import NAIP, BoundingBox, IntersectionDataset, UnionDataset class TestNAIP: @@ -60,27 +54,3 @@ def test_invalid_query(self, dataset: NAIP) -> None: IndexError, match="query: .* not found in index with bounds:" ): dataset[query] - - -class TestNAIPChesapeakeDataModule: - @pytest.fixture(scope="class") - def datamodule(self) -> NAIPChesapeakeDataModule: - dm = NAIPChesapeakeDataModule( - os.path.join("tests", "data", "naip"), - os.path.join("tests", "data", "chesapeake", "BAYWIDE"), - batch_size=2, - num_workers=0, - ) - dm.patch_size = 32 - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: NAIPChesapeakeDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: NAIPChesapeakeDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: NAIPChesapeakeDataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_nasa_marine_debris.py b/tests/datasets/test_nasa_marine_debris.py index a8b20c2cd29..deb8366ddfd 100644 --- a/tests/datasets/test_nasa_marine_debris.py +++ b/tests/datasets/test_nasa_marine_debris.py @@ -13,7 +13,7 @@ import torch.nn as nn from _pytest.monkeypatch import MonkeyPatch -from torchgeo.datasets import NASAMarineDebris, NASAMarineDebrisDataModule +from torchgeo.datasets import NASAMarineDebris class Dataset: @@ -85,28 +85,3 @@ def test_plot(self, dataset: NASAMarineDebris) -> None: x["prediction_boxes"] = x["boxes"].clone() dataset.plot(x) plt.close() - - -class TestNASAMarineDebrisDataModule: - @pytest.fixture(scope="class") - def datamodule(self) -> NASAMarineDebrisDataModule: - root = os.path.join("tests", "data", "nasa_marine_debris") - batch_size = 2 - num_workers = 0 - val_split_pct = 0.3 - test_split_pct = 0.3 - dm = NASAMarineDebrisDataModule( - root, batch_size, num_workers, val_split_pct, test_split_pct - ) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: NASAMarineDebrisDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: NASAMarineDebrisDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: NASAMarineDebrisDataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_oscd.py b/tests/datasets/test_oscd.py index 2bfaf25d50d..a8e497cbc8d 100644 --- a/tests/datasets/test_oscd.py +++ b/tests/datasets/test_oscd.py @@ -16,7 +16,7 @@ from torch.utils.data import ConcatDataset import torchgeo.datasets.utils -from torchgeo.datasets import OSCD, OSCDDataModule +from torchgeo.datasets import OSCD def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: @@ -34,8 +34,18 @@ def dataset( monkeypatch.setattr( # type: ignore[attr-defined] torchgeo.datasets.oscd, "download_url", download_url ) - md5 = "d6ebaae1ea0f3ae960af31531d394521" - monkeypatch.setattr(OSCD, "md5", md5) # type: ignore[attr-defined] + md5s = { + "Onera Satellite Change Detection dataset - Images.zip": ( + "fb4e3f54c3a31fd3f21f98cad4ddfb74" + ), + "Onera Satellite Change Detection dataset - Train Labels.zip": ( + "ca526434a60e9abdf97d528dc29e9f13" + ), + "Onera Satellite Change Detection dataset - Test Labels.zip": ( + "ca0ba73ba66d06fa4903e269ef12eb50" + ), + } + monkeypatch.setattr(OSCD, "md5s", md5s) # type: ignore[attr-defined] urls = { "Onera Satellite Change Detection dataset - Images.zip": os.path.join( "tests", @@ -105,56 +115,3 @@ def test_not_downloaded(self, tmp_path: Path) -> None: def test_plot(self, dataset: OSCD) -> None: dataset.plot(dataset[0], suptitle="Test") plt.close() - - -class TestOSCDDataModule: - @pytest.fixture(scope="class", params=zip(["all", "rgb"], [0.0, 0.5])) - def datamodule(self, request: SubRequest) -> OSCDDataModule: - bands, val_split_pct = request.param - patch_size = (2, 2) - num_patches_per_tile = 2 - root = os.path.join("tests", "data", "oscd") - batch_size = 1 - num_workers = 0 - dm = OSCDDataModule( - root, - bands, - batch_size, - num_workers, - val_split_pct, - patch_size, - num_patches_per_tile, - ) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: OSCDDataModule) -> None: - sample = next(iter(datamodule.train_dataloader())) - assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2) - assert sample["image"].shape[0] == sample["mask"].shape[0] == 2 - if datamodule.bands == "all": - assert sample["image"].shape[1] == 26 - else: - assert sample["image"].shape[1] == 6 - - def test_val_dataloader(self, datamodule: OSCDDataModule) -> None: - sample = next(iter(datamodule.val_dataloader())) - if datamodule.val_split_pct > 0.0: - assert ( - sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (1280, 1280) - ) - assert sample["image"].shape[0] == sample["mask"].shape[0] == 1 - if datamodule.bands == "all": - assert sample["image"].shape[1] == 26 - else: - assert sample["image"].shape[1] == 6 - - def test_test_dataloader(self, datamodule: OSCDDataModule) -> None: - sample = next(iter(datamodule.test_dataloader())) - assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (1280, 1280) - assert sample["image"].shape[0] == sample["mask"].shape[0] == 1 - if datamodule.bands == "all": - assert sample["image"].shape[1] == 26 - else: - assert sample["image"].shape[1] == 6 diff --git a/tests/datasets/test_potsdam.py b/tests/datasets/test_potsdam.py index b11d0dc138e..6a298baf359 100644 --- a/tests/datasets/test_potsdam.py +++ b/tests/datasets/test_potsdam.py @@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch -from torchgeo.datasets import Potsdam2D, Potsdam2DDataModule +from torchgeo.datasets import Potsdam2D class TestPotsdam2D: @@ -75,27 +75,3 @@ def test_plot(self, dataset: Potsdam2D) -> None: x["prediction"] = x["mask"].clone() dataset.plot(x) plt.close() - - -class TestPotsdam2DDataModule: - @pytest.fixture(scope="class", params=[0.0, 0.5]) - def datamodule(self, request: SubRequest) -> Potsdam2DDataModule: - root = os.path.join("tests", "data", "potsdam") - batch_size = 1 - num_workers = 0 - val_split_size = request.param - dm = Potsdam2DDataModule( - root, batch_size, num_workers, val_split_pct=val_split_size - ) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: Potsdam2DDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: Potsdam2DDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: Potsdam2DDataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_resisc45.py b/tests/datasets/test_resisc45.py index 75ed6ee2d58..c8f4d9157d1 100644 --- a/tests/datasets/test_resisc45.py +++ b/tests/datasets/test_resisc45.py @@ -15,7 +15,7 @@ from _pytest.monkeypatch import MonkeyPatch import torchgeo.datasets.utils -from torchgeo.datasets import RESISC45, RESISC45DataModule +from torchgeo.datasets import RESISC45 def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: @@ -101,24 +101,3 @@ def test_plot(self, dataset: RESISC45) -> None: x["prediction"] = x["label"].clone() dataset.plot(x) plt.close() - - -class TestRESISC45DataModule: - @pytest.fixture(scope="class") - def datamodule(self) -> RESISC45DataModule: - root = os.path.join("tests", "data", "resisc45") - batch_size = 2 - num_workers = 0 - dm = RESISC45DataModule(root, batch_size, num_workers) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: RESISC45DataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: RESISC45DataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: RESISC45DataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_sen12ms.py b/tests/datasets/test_sen12ms.py index 2332e70f39d..ee8969b22b1 100644 --- a/tests/datasets/test_sen12ms.py +++ b/tests/datasets/test_sen12ms.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import Generator +import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn @@ -12,7 +13,7 @@ from _pytest.monkeypatch import MonkeyPatch from torch.utils.data import ConcatDataset -from torchgeo.datasets import SEN12MS, SEN12MSDataModule +from torchgeo.datasets import SEN12MS class TestSEN12MS: @@ -83,25 +84,20 @@ def test_band_subsets(self) -> None: x = ds[0]["image"] assert x.shape[0] == len(bands) + def test_invalid_bands(self) -> None: + with pytest.raises(ValueError): + SEN12MS(bands=tuple(["OK", "BK"])) -class TestSEN12MSDataModule: - @pytest.fixture(scope="class", params=["all", "s1", "s2-all", "s2-reduced"]) - def datamodule(self, request: SubRequest) -> SEN12MSDataModule: - root = os.path.join("tests", "data", "sen12ms") - seed = 0 - bands = request.param - batch_size = 1 - num_workers = 0 - dm = SEN12MSDataModule(root, seed, bands, batch_size, num_workers) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: SEN12MSDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: SEN12MSDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: SEN12MSDataModule) -> None: - next(iter(datamodule.test_dataloader())) + def test_plot(self, dataset: SEN12MS) -> None: + dataset.plot(dataset[0], suptitle="Test") + plt.close() + + sample = dataset[0] + sample["prediction"] = sample["mask"].clone() + dataset.plot(sample, suptitle="prediction") + plt.close() + + def test_plot_rgb(self, dataset: SEN12MS) -> None: + dataset = SEN12MS(root=dataset.root, bands=tuple(["B03"])) + with pytest.raises(ValueError, match="doesn't contain some of the RGB bands"): + dataset.plot(dataset[0], suptitle="Single Band") diff --git a/tests/datasets/test_so2sat.py b/tests/datasets/test_so2sat.py index ab4085ba5e8..5229ea14bb4 100644 --- a/tests/datasets/test_so2sat.py +++ b/tests/datasets/test_so2sat.py @@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch -from torchgeo.datasets import So2Sat, So2SatDataModule +from torchgeo.datasets import So2Sat pytest.importorskip("h5py") @@ -24,9 +24,9 @@ def dataset( self, monkeypatch: Generator[MonkeyPatch, None, None], request: SubRequest ) -> So2Sat: md5s = { - "train": "2fa6b9d8995e3b6272af42719f05aaa2", - "validation": "fe3dbf74971766d5038f6cbc0b1390ae", - "test": "87d428eff44267ca642fc739cc442331", + "train": "82e0f2d51766b89cb905dbaf8275eb5b", + "validation": "bf292ae4737c1698b1a3c6f5e742e0e1", + "test": "9a3bbe181b038d4e51f122c4be3c569e", } monkeypatch.setattr(So2Sat, "md5s", md5s) # type: ignore[attr-defined] @@ -57,13 +57,13 @@ def test_getitem(self, dataset: So2Sat) -> None: assert isinstance(x["label"], torch.Tensor) def test_len(self, dataset: So2Sat) -> None: - assert len(dataset) == 10 + assert len(dataset) == 1 def test_out_of_bounds(self, dataset: So2Sat) -> None: # h5py at version 2.10.0 raises a ValueError instead of an IndexError so we # check for both here with pytest.raises((IndexError, ValueError)): - dataset[10] + dataset[1] def test_invalid_split(self) -> None: with pytest.raises(AssertionError): @@ -91,25 +91,3 @@ def test_mock_missing_module( match="h5py is not installed and is required to use this dataset", ): So2Sat(dataset.root) - - -class TestSo2SatDataModule: - @pytest.fixture(scope="class", params=zip([True, False], ["rgb", "s2"])) - def datamodule(self, request: SubRequest) -> So2SatDataModule: - unsupervised_mode, bands = request.param - root = os.path.join("tests", "data", "so2sat") - batch_size = 2 - num_workers = 0 - dm = So2SatDataModule(root, batch_size, num_workers, bands, unsupervised_mode) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: So2SatDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: So2SatDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: So2SatDataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_ucmerced.py b/tests/datasets/test_ucmerced.py index ad6efb6628b..600c2595d4d 100644 --- a/tests/datasets/test_ucmerced.py +++ b/tests/datasets/test_ucmerced.py @@ -15,7 +15,7 @@ from torch.utils.data import ConcatDataset import torchgeo.datasets.utils -from torchgeo.datasets import UCMerced, UCMercedDataModule +from torchgeo.datasets import UCMerced def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: @@ -102,24 +102,3 @@ def test_plot(self, dataset: UCMerced) -> None: x["prediction"] = x["label"].clone() dataset.plot(x) plt.close() - - -class TestUCMercedDataModule: - @pytest.fixture(scope="class") - def datamodule(self) -> UCMercedDataModule: - root = os.path.join("tests", "data", "ucmerced") - batch_size = 2 - num_workers = 0 - dm = UCMercedDataModule(root, batch_size, num_workers) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: UCMercedDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: UCMercedDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: UCMercedDataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py index 9e8732deac4..631897bb7f2 100644 --- a/tests/datasets/test_utils.py +++ b/tests/datasets/test_utils.py @@ -18,13 +18,11 @@ import torch from _pytest.monkeypatch import MonkeyPatch from rasterio.crs import CRS -from torch.utils.data import TensorDataset import torchgeo.datasets.utils from torchgeo.datasets.utils import ( BoundingBox, concat_samples, - dataset_split, disambiguate_timestamp, download_and_extract_archive, download_radiant_mlhub_collection, @@ -563,24 +561,6 @@ def test_nonexisting_directory(tmp_path: Path) -> None: assert subdir.cwd() == subdir -def test_dataset_split() -> None: - num_samples = 24 - x = torch.ones(num_samples, 5) # type: ignore[attr-defined] - y = torch.randint(low=0, high=2, size=(num_samples,)) # type: ignore[attr-defined] - ds = TensorDataset(x, y) - - # Test only train/val set split - train_ds, val_ds = dataset_split(ds, val_pct=1 / 2) - assert len(train_ds) == num_samples // 2 - assert len(val_ds) == num_samples // 2 - - # Test train/val/test set split - train_ds, val_ds, test_ds = dataset_split(ds, val_pct=1 / 3, test_pct=1 / 3) - assert len(train_ds) == num_samples // 3 - assert len(val_ds) == num_samples // 3 - assert len(test_ds) == num_samples // 3 - - def test_percentile_normalization() -> None: img = np.array([[1, 2], [98, 100]]) diff --git a/tests/datasets/test_vaihingen.py b/tests/datasets/test_vaihingen.py index 531dd24e592..033017ea0ee 100644 --- a/tests/datasets/test_vaihingen.py +++ b/tests/datasets/test_vaihingen.py @@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch -from torchgeo.datasets import Vaihingen2D, Vaihingen2DDataModule +from torchgeo.datasets import Vaihingen2D class TestVaihingen2D: @@ -84,27 +84,3 @@ def test_plot(self, dataset: Vaihingen2D) -> None: x["prediction"] = x["mask"].clone() dataset.plot(x) plt.close() - - -class TestVaihingen2DDataModule: - @pytest.fixture(scope="class", params=[0.0, 0.5]) - def datamodule(self, request: SubRequest) -> Vaihingen2DDataModule: - root = os.path.join("tests", "data", "vaihingen") - batch_size = 1 - num_workers = 0 - val_split_size = request.param - dm = Vaihingen2DDataModule( - root, batch_size, num_workers, val_split_pct=val_split_size - ) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: Vaihingen2DDataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: Vaihingen2DDataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: Vaihingen2DDataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_xview2.py b/tests/datasets/test_xview2.py index b358337617c..92e00f4c7fd 100644 --- a/tests/datasets/test_xview2.py +++ b/tests/datasets/test_xview2.py @@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch -from torchgeo.datasets import XView2, XView2DataModule +from torchgeo.datasets import XView2 class TestXView2: @@ -95,27 +95,3 @@ def test_plot(self, dataset: XView2) -> None: x["prediction"] = x["mask"][0].clone() dataset.plot(x) plt.close() - - -class TestXView2DataModule: - @pytest.fixture(scope="class", params=[0.0, 0.5]) - def datamodule(self, request: SubRequest) -> XView2DataModule: - root = os.path.join("tests", "data", "xview2") - batch_size = 1 - num_workers = 0 - val_split_size = request.param - dm = XView2DataModule( - root, batch_size, num_workers, val_split_pct=val_split_size - ) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader(self, datamodule: XView2DataModule) -> None: - next(iter(datamodule.train_dataloader())) - - def test_val_dataloader(self, datamodule: XView2DataModule) -> None: - next(iter(datamodule.val_dataloader())) - - def test_test_dataloader(self, datamodule: XView2DataModule) -> None: - next(iter(datamodule.test_dataloader())) diff --git a/tests/losses/__init__.py b/tests/losses/__init__.py new file mode 100644 index 00000000000..5b7f7a925cc --- /dev/null +++ b/tests/losses/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. diff --git a/tests/losses/test_qr.py b/tests/losses/test_qr.py new file mode 100644 index 00000000000..ce6ab19254e --- /dev/null +++ b/tests/losses/test_qr.py @@ -0,0 +1,20 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import torch + +from torchgeo.losses import QRLoss, RQLoss + + +class TestQRLosses: + def test_loss_on_prior_simple(self) -> None: + probs = torch.rand(2, 4, 10, 10) + log_probs = torch.log(probs) # type: ignore[attr-defined] + targets = torch.rand(2, 4, 10, 10) + QRLoss()(log_probs, targets) + + def test_loss_on_prior_reversed_kl_simple(self) -> None: + probs = torch.rand(2, 4, 10, 10) + log_probs = torch.log(probs) # type: ignore[attr-defined] + targets = torch.rand(2, 4, 10, 10) + RQLoss()(log_probs, targets) diff --git a/tests/test_train.py b/tests/test_train.py index 786dd6ad9fe..5b853caa130 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -132,36 +132,3 @@ def test_config_file(tmp_path: Path) -> None: ) args = [sys.executable, "train.py", "config_file=" + str(config_file)] subprocess.run(args, check=True) - - -@pytest.mark.parametrize( - "task", - [ - "bigearthnet", - "byol", - "chesapeake_cvpr", - "cowc_counting", - "cyclone", - "landcoverai", - "naipchesapeake", - "resisc45", - "sen12ms", - "so2sat", - "ucmerced", - ], -) -def test_tasks(task: str, tmp_path: Path) -> None: - output_dir = tmp_path / "output" - log_dir = tmp_path / "logs" - args = [ - sys.executable, - "train.py", - "experiment.name=test", - "program.output_dir=" + str(output_dir), - "program.log_dir=" + str(log_dir), - "trainer.fast_dev_run=1", - "experiment.task=" + task, - "program.overwrite=True", - "config_file=" + os.path.join("conf", "task_defaults", task + ".yaml"), - ] - subprocess.run(args, check=True) diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index ac5e9e2b792..bb661d185ad 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -2,22 +2,18 @@ # Licensed under the MIT License. import os -from typing import Any, Dict, Generator, cast +from typing import Any, Dict, Type, cast import pytest import torch.nn as nn -from _pytest.fixtures import SubRequest -from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf -from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning import LightningDataModule, Trainer from torchvision.models import resnet18 -from torchgeo.datasets import ChesapeakeCVPRDataModule +from torchgeo.datamodules import ChesapeakeCVPRDataModule from torchgeo.trainers import BYOLTask from torchgeo.trainers.byol import BYOL, SimCLRAugmentation -from .test_utils import mocked_log - class TestBYOL: def test_custom_augment_fn(self) -> None: @@ -37,61 +33,37 @@ def test_custom_augment_fn(self) -> None: class TestBYOLTask: - @pytest.fixture(scope="class") - def datamodule(self) -> ChesapeakeCVPRDataModule: - dm = ChesapeakeCVPRDataModule( - os.path.join("tests", "data", "chesapeake", "cvpr"), - ["de-test"], - ["de-test"], - ["de-test"], - patch_size=4, - patches_per_tile=2, - batch_size=2, - num_workers=0, - ) - dm.prepare_data() - dm.setup() - return dm - - @pytest.fixture(params=["resnet18", "resnet50"]) - def config(self, request: SubRequest) -> Dict[str, Any]: - task_conf = OmegaConf.load(os.path.join("conf", "task_defaults", "byol.yaml")) - task_args = OmegaConf.to_object(task_conf.experiment.module) - task_args = cast(Dict[str, Any], task_args) - task_args["encoder"] = request.param - return task_args - - @pytest.fixture - def task( - self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] - ) -> LightningModule: - task = BYOLTask(**config) - monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] - return task - - def test_configure_optimizers(self, task: BYOLTask) -> None: - out = task.configure_optimizers() - assert "optimizer" in out - assert "lr_scheduler" in out + @pytest.mark.parametrize( + "name,classname", + [ + ("chesapeake_cvpr_7", ChesapeakeCVPRDataModule), + ("chesapeake_cvpr_prior", ChesapeakeCVPRDataModule), + ], + ) + def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: + conf = OmegaConf.load(os.path.join("conf", "task_defaults", name + ".yaml")) + conf_dict = OmegaConf.to_object(conf.experiment) + conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) - def test_training( - self, datamodule: ChesapeakeCVPRDataModule, task: BYOLTask - ) -> None: - batch = next(iter(datamodule.train_dataloader())) - task.training_step(batch, 0) + # Instantiate datamodule + datamodule_kwargs = conf_dict["datamodule"] + datamodule = classname(**datamodule_kwargs) - def test_validation( - self, datamodule: ChesapeakeCVPRDataModule, task: BYOLTask - ) -> None: - batch = next(iter(datamodule.val_dataloader())) - task.validation_step(batch, 0) + # Instantiate model + model_kwargs = conf_dict["module"] + model = BYOLTask(**model_kwargs) - def test_test(self, datamodule: ChesapeakeCVPRDataModule, task: BYOLTask) -> None: - batch = next(iter(datamodule.test_dataloader())) - task.test_step(batch, 0) + # Instantiate trainer + trainer = Trainer(fast_dev_run=True, log_every_n_steps=1) + trainer.fit(model=model, datamodule=datamodule) + trainer.test(model=model, datamodule=datamodule) - def test_invalid_encoder(self, config: Dict[str, Any]) -> None: - config["encoder"] = "invalid_encoder" + def test_invalid_encoder(self) -> None: + kwargs = { + "in_channels": 1, + "imagenet_pretraining": False, + "encoder_name": "invalid_encoder", + } error_message = "Encoder type 'invalid_encoder' is not valid." with pytest.raises(ValueError, match=error_message): - BYOLTask(**config) + BYOLTask(**kwargs) diff --git a/tests/trainers/test_chesapeake.py b/tests/trainers/test_chesapeake.py index a9c95907dbc..377d4a85e35 100644 --- a/tests/trainers/test_chesapeake.py +++ b/tests/trainers/test_chesapeake.py @@ -9,7 +9,7 @@ from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf -from torchgeo.datasets import ChesapeakeCVPRDataModule +from torchgeo.datamodules import ChesapeakeCVPRDataModule from torchgeo.trainers.chesapeake import ChesapeakeCVPRSegmentationTask from .test_utils import FakeTrainer, mocked_log @@ -40,11 +40,10 @@ def datamodule(self, class_set: int) -> ChesapeakeCVPRDataModule: @pytest.fixture def config(self, class_set: int) -> Dict[str, Any]: task_conf = OmegaConf.load( - os.path.join("conf", "task_defaults", "chesapeake_cvpr.yaml") + os.path.join("conf", "task_defaults", f"chesapeake_cvpr_{class_set}.yaml") ) task_args = OmegaConf.to_object(task_conf.experiment.module) task_args = cast(Dict[str, Any], task_args) - task_args["num_classes"] = class_set return task_args @pytest.fixture diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index 4033f7bf5ee..f2c091c75ee 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -2,250 +2,136 @@ # Licensed under the MIT License. import os -from typing import Any, Dict, Generator, Optional, cast +from typing import Any, Dict, Type, cast import pytest -import pytorch_lightning as pl -import torch -import torch.nn.functional as F -from _pytest.fixtures import SubRequest -from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf -from torch import Tensor -from torch.utils.data import DataLoader, Dataset, TensorDataset - +from pytorch_lightning import LightningDataModule, Trainer + +from torchgeo.datamodules import ( + BigEarthNetDataModule, + EuroSATDataModule, + RESISC45DataModule, + So2SatDataModule, + UCMercedDataModule, +) from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask -from .test_utils import mocked_log - -class DummyDataset(Dataset): # type: ignore[type-arg] - def __init__(self, num_channels: int, num_classes: int, multilabel: bool) -> None: - x = torch.randn(10, num_channels, 128, 128) # (b, c, h, w) - y = torch.randint( # type: ignore[attr-defined] - 0, num_classes, size=(10,) - ) # (b,) +class TestClassificationTask: + @pytest.mark.parametrize( + "name,classname", + [ + ("eurosat", EuroSATDataModule), + ("resisc45", RESISC45DataModule), + ("so2sat_supervised", So2SatDataModule), + ("so2sat_unsupervised", So2SatDataModule), + ("ucmerced", UCMercedDataModule), + ], + ) + def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: + if name == "so2sat": + pytest.importorskip("h5py") - if multilabel: - y = F.one_hot(y, num_classes=num_classes) # (b, classes) + conf = OmegaConf.load(os.path.join("conf", "task_defaults", name + ".yaml")) + conf_dict = OmegaConf.to_object(conf.experiment) + conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) - self.dataset = TensorDataset(x, y) + # Instantiate datamodule + datamodule_kwargs = conf_dict["datamodule"] + datamodule = classname(**datamodule_kwargs) - def __len__(self) -> int: - return len(self.dataset) + # Instantiate model + model_kwargs = conf_dict["module"] + model = ClassificationTask(**model_kwargs) - def __getitem__(self, idx: int) -> Dict[str, Tensor]: - x, y = self.dataset[idx] - sample = {"image": x, "label": y} - return sample + # Instantiate trainer + trainer = Trainer(fast_dev_run=True, log_every_n_steps=1) + trainer.fit(model=model, datamodule=datamodule) + trainer.test(model=model, datamodule=datamodule) + @pytest.fixture + def model_kwargs(self) -> Dict[Any, Any]: + return { + "classification_model": "resnet18", + "in_channels": 1, + "loss": "ce", + "num_classes": 1, + "weights": "random", + } + + def test_pretrained(self, model_kwargs: Dict[Any, Any], checkpoint: str) -> None: + model_kwargs["weights"] = checkpoint + with pytest.warns(UserWarning): + ClassificationTask(**model_kwargs) -class DummyDataModule(pl.LightningDataModule): - def __init__( - self, - num_channels: int, - num_classes: int, - multilabel: bool, - batch_size: int = 1, - num_workers: int = 0, + def test_invalid_pretrained( + self, model_kwargs: Dict[Any, Any], checkpoint: str ) -> None: - super().__init__() # type: ignore[no-untyped-call] - self.num_channels = num_channels - self.num_classes = num_classes - self.multilabel = multilabel - self.batch_size = batch_size - self.num_workers = num_workers - - def setup(self, stage: Optional[str] = None) -> None: - self.dataset = DummyDataset( - num_channels=self.num_channels, - num_classes=self.num_classes, - multilabel=self.multilabel, - ) - - def train_dataloader(self) -> DataLoader: # type: ignore[type-arg] - return DataLoader( - self.dataset, batch_size=self.batch_size, num_workers=self.num_workers - ) - - def val_dataloader(self) -> DataLoader: # type: ignore[type-arg] - return DataLoader( - self.dataset, batch_size=self.batch_size, num_workers=self.num_workers - ) - - def test_dataloader(self) -> DataLoader: # type: ignore[type-arg] - return DataLoader( - self.dataset, batch_size=self.batch_size, num_workers=self.num_workers - ) - + model_kwargs["weights"] = checkpoint + model_kwargs["classification_model"] = "resnet50" + match = "Trying to load resnet18 weights into a resnet50" + with pytest.raises(ValueError, match=match): + ClassificationTask(**model_kwargs) + + def test_invalid_loss(self, model_kwargs: Dict[Any, Any]) -> None: + model_kwargs["loss"] = "invalid_loss" + match = "Loss type 'invalid_loss' is not valid." + with pytest.raises(ValueError, match=match): + ClassificationTask(**model_kwargs) + + def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None: + model_kwargs["classification_model"] = "invalid_model" + match = "Model type 'invalid_model' is not a valid timm model." + with pytest.raises(ValueError, match=match): + ClassificationTask(**model_kwargs) + + def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None: + model_kwargs["weights"] = "invalid_weights" + match = "Weight type 'invalid_weights' is not valid." + with pytest.raises(ValueError, match=match): + ClassificationTask(**model_kwargs) -class TestClassificationTask: - num_classes = 10 - - @pytest.fixture(scope="class", params=[2, 3, 5]) - def datamodule(self, request: SubRequest) -> DummyDataModule: - dm = DummyDataModule( - num_channels=request.param, - num_classes=self.num_classes, - multilabel=False, - batch_size=2, - num_workers=0, - ) - dm.prepare_data() - dm.setup() - return dm - - @pytest.fixture( - scope="class", - params=zip( - ["ce", "jaccard", "focal"], - ["random", "random", "random"], - ["resnet18", "hrnet_w18_small_v2", "tf_efficientnet_b0"], - ), +class TestMultiLabelClassificationTask: + @pytest.mark.parametrize( + "name,classname", + [ + ("bigearthnet_all", BigEarthNetDataModule), + ("bigearthnet_s1", BigEarthNetDataModule), + ("bigearthnet_s2", BigEarthNetDataModule), + ], ) - def config( - self, request: SubRequest, datamodule: DummyDataModule - ) -> Dict[str, Any]: - loss, weights, model = request.param - task_args: Dict[str, Any] = {} - task_args["classification_model"] = model - task_args["learning_rate"] = 3e-4 - task_args["learning_rate_schedule_patience"] = 6 - task_args["in_channels"] = datamodule.num_channels - task_args["loss"] = loss - task_args["num_classes"] = self.num_classes - task_args["weights"] = weights - return task_args + def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: + conf = OmegaConf.load(os.path.join("conf", "task_defaults", name + ".yaml")) + conf_dict = OmegaConf.to_object(conf.experiment) + conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) - @pytest.fixture - def task( - self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] - ) -> ClassificationTask: - task = ClassificationTask(**config) - monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] - return task - - def test_configure_optimizers(self, task: ClassificationTask) -> None: - out = task.configure_optimizers() - assert "optimizer" in out - assert "lr_scheduler" in out - - def test_training( - self, datamodule: DummyDataModule, task: ClassificationTask - ) -> None: - batch = next(iter(datamodule.train_dataloader())) - task.training_step(batch, 0) - task.training_epoch_end(0) + # Instantiate datamodule + datamodule_kwargs = conf_dict["datamodule"] + datamodule = classname(**datamodule_kwargs) - def test_validation( - self, datamodule: DummyDataModule, task: ClassificationTask - ) -> None: - batch = next(iter(datamodule.val_dataloader())) - task.validation_step(batch, 0) - task.validation_epoch_end(0) - - def test_test(self, datamodule: DummyDataModule, task: ClassificationTask) -> None: - batch = next(iter(datamodule.test_dataloader())) - task.test_step(batch, 0) - task.test_epoch_end(0) - - def test_pretrained(self, checkpoint: str) -> None: - task_conf = OmegaConf.load(os.path.join("conf", "task_defaults", "so2sat.yaml")) - task_args = OmegaConf.to_object(task_conf.experiment.module) - task_args = cast(Dict[str, Any], task_args) - task_args["weights"] = checkpoint - with pytest.warns(UserWarning): - ClassificationTask(**task_args) - - def test_invalid_model(self, config: Dict[str, Any]) -> None: - config["classification_model"] = "invalid_model" - error_message = "Model type 'invalid_model' is not a valid timm model." - with pytest.raises(ValueError, match=error_message): - ClassificationTask(**config) - - def test_invalid_loss(self, config: Dict[str, Any]) -> None: - config["loss"] = "invalid_loss" - config["classification_model"] = "resnet18" - error_message = "Loss type 'invalid_loss' is not valid." - with pytest.raises(ValueError, match=error_message): - ClassificationTask(**config) - - def test_invalid_weights(self, config: Dict[str, Any]) -> None: - config["weights"] = "invalid_weights" - error_message = "Weight type 'invalid_weights' is not valid." - with pytest.raises(ValueError, match=error_message): - ClassificationTask(**config) - - def test_invalid_pretrained(self, checkpoint: str, config: Dict[str, Any]) -> None: - config["weights"] = checkpoint - config["classification_model"] = "resnet50" - error_message = "Trying to load resnet18 weights into a resnet50" - with pytest.raises(ValueError, match=error_message): - ClassificationTask(**config) + # Instantiate model + model_kwargs = conf_dict["module"] + model = MultiLabelClassificationTask(**model_kwargs) - -class TestMultiLabelClassificationTask: - - num_classes = 10 - - @pytest.fixture(scope="class") - def datamodule(self, request: SubRequest) -> DummyDataModule: - dm = DummyDataModule( - num_channels=3, - num_classes=self.num_classes, - multilabel=True, - batch_size=2, - num_workers=0, - ) - dm.prepare_data() - dm.setup() - return dm - - @pytest.fixture(scope="class", params=zip(["bce", "bce"], ["random", "random"])) - def config( - self, datamodule: DummyDataModule, request: SubRequest - ) -> Dict[str, Any]: - task_args: Dict[str, Any] = {} - task_args["classification_model"] = "resnet18" - task_args["learning_rate"] = 3e-4 - task_args["learning_rate_schedule_patience"] = 6 - task_args["in_channels"] = datamodule.num_channels - loss, weights = request.param - task_args["loss"] = loss - task_args["num_classes"] = self.num_classes - task_args["weights"] = weights - return task_args + # Instantiate trainer + trainer = Trainer(fast_dev_run=True, log_every_n_steps=1) + trainer.fit(model=model, datamodule=datamodule) + trainer.test(model=model, datamodule=datamodule) @pytest.fixture - def task( - self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] - ) -> MultiLabelClassificationTask: - task = MultiLabelClassificationTask(**config) - monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] - return task - - def test_training( - self, datamodule: DummyDataModule, task: ClassificationTask - ) -> None: - batch = next(iter(datamodule.train_dataloader())) - task.training_step(batch, 0) - task.training_epoch_end(0) - - def test_validation( - self, datamodule: DummyDataModule, task: ClassificationTask - ) -> None: - batch = next(iter(datamodule.val_dataloader())) - task.validation_step(batch, 0) - task.validation_epoch_end(0) - - def test_test(self, datamodule: DummyDataModule, task: ClassificationTask) -> None: - batch = next(iter(datamodule.test_dataloader())) - task.test_step(batch, 0) - task.test_epoch_end(0) - - def test_invalid_loss(self, config: Dict[str, Any]) -> None: - config["loss"] = "invalid_loss" - error_message = "Loss type 'invalid_loss' is not valid." - with pytest.raises(ValueError, match=error_message): - MultiLabelClassificationTask(**config) + def model_kwargs(self) -> Dict[Any, Any]: + return { + "classification_model": "resnet18", + "in_channels": 1, + "loss": "ce", + "num_classes": 1, + "weights": "random", + } + + def test_invalid_loss(self, model_kwargs: Dict[Any, Any]) -> None: + model_kwargs["loss"] = "invalid_loss" + match = "Loss type 'invalid_loss' is not valid." + with pytest.raises(ValueError, match=match): + MultiLabelClassificationTask(**model_kwargs) diff --git a/tests/trainers/test_landcoverai.py b/tests/trainers/test_landcoverai.py index b14d5b46684..d3e70dfb098 100644 --- a/tests/trainers/test_landcoverai.py +++ b/tests/trainers/test_landcoverai.py @@ -8,7 +8,7 @@ from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf -from torchgeo.datasets import LandCoverAIDataModule +from torchgeo.datamodules import LandCoverAIDataModule from torchgeo.trainers.landcoverai import LandCoverAISegmentationTask from .test_utils import FakeTrainer, mocked_log diff --git a/tests/trainers/test_naipchesapeake.py b/tests/trainers/test_naipchesapeake.py index 3b8cce5aca0..37d94cb0ed8 100644 --- a/tests/trainers/test_naipchesapeake.py +++ b/tests/trainers/test_naipchesapeake.py @@ -8,7 +8,7 @@ from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf -from torchgeo.datasets import NAIPChesapeakeDataModule +from torchgeo.datamodules import NAIPChesapeakeDataModule from torchgeo.trainers.naipchesapeake import NAIPChesapeakeSegmentationTask from .test_utils import FakeTrainer, mocked_log diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index cfa7e16924b..574c930e7d6 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -2,73 +2,40 @@ # Licensed under the MIT License. import os -from typing import Any, Dict, Generator, cast +from typing import Any, Dict, Type, cast import pytest -from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf +from pytorch_lightning import LightningDataModule, Trainer -from torchgeo.datasets import CycloneDataModule +from torchgeo.datamodules import COWCCountingDataModule, CycloneDataModule from torchgeo.trainers import RegressionTask -from .test_utils import mocked_log - class TestRegressionTask: - @pytest.fixture(scope="class") - def datamodule(self) -> CycloneDataModule: - root = os.path.join("tests", "data", "cyclone") - seed = 0 - batch_size = 1 - num_workers = 0 - dm = CycloneDataModule(root, seed, batch_size, num_workers) - dm.prepare_data() - dm.setup() - return dm - - @pytest.fixture - def config(self) -> Dict[str, Any]: - task_conf = OmegaConf.load( - os.path.join("conf", "task_defaults", "cyclone.yaml") - ) - task_args = OmegaConf.to_object(task_conf.experiment.module) - task_args = cast(Dict[str, Any], task_args) - return task_args - - @pytest.fixture - def task( - self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] - ) -> RegressionTask: - task = RegressionTask(**config) - monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] - return task - - def test_configure_optimizers(self, task: RegressionTask) -> None: - out = task.configure_optimizers() - assert "optimizer" in out - assert "lr_scheduler" in out - - def test_training( - self, datamodule: CycloneDataModule, task: RegressionTask - ) -> None: - batch = next(iter(datamodule.train_dataloader())) - task.training_step(batch, 0) - task.training_epoch_end(0) - - def test_validation( - self, datamodule: CycloneDataModule, task: RegressionTask - ) -> None: - batch = next(iter(datamodule.val_dataloader())) - task.validation_step(batch, 0) - task.validation_epoch_end(0) - - def test_test(self, datamodule: CycloneDataModule, task: RegressionTask) -> None: - batch = next(iter(datamodule.test_dataloader())) - task.test_step(batch, 0) - task.test_epoch_end(0) - - def test_invalid_model(self, config: Dict[str, Any]) -> None: - config["model"] = "invalid_model" - error_message = "Model type 'invalid_model' is not valid." - with pytest.raises(ValueError, match=error_message): - RegressionTask(**config) + @pytest.mark.parametrize( + "name,classname", + [("cowc_counting", COWCCountingDataModule), ("cyclone", CycloneDataModule)], + ) + def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: + conf = OmegaConf.load(os.path.join("conf", "task_defaults", name + ".yaml")) + conf_dict = OmegaConf.to_object(conf.experiment) + conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) + + # Instantiate datamodule + datamodule_kwargs = conf_dict["datamodule"] + datamodule = classname(**datamodule_kwargs) + + # Instantiate model + model_kwargs = conf_dict["module"] + model = RegressionTask(**model_kwargs) + + # Instantiate trainer + trainer = Trainer(fast_dev_run=True, log_every_n_steps=1) + trainer.fit(model=model, datamodule=datamodule) + trainer.test(model=model, datamodule=datamodule) + + def test_invalid_model(self) -> None: + match = "Model type 'invalid_model' is not valid." + with pytest.raises(ValueError, match=match): + RegressionTask(model="invalid_model") diff --git a/tests/trainers/test_resisc45.py b/tests/trainers/test_resisc45.py index 0b832295faf..1eec36e2fee 100644 --- a/tests/trainers/test_resisc45.py +++ b/tests/trainers/test_resisc45.py @@ -7,7 +7,7 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from torchgeo.datasets import RESISC45DataModule +from torchgeo.datamodules import RESISC45DataModule from torchgeo.trainers.resisc45 import RESISC45ClassificationTask from .test_utils import FakeTrainer, mocked_log diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 058f94170b7..1d13299334e 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -2,95 +2,77 @@ # Licensed under the MIT License. import os -from typing import Any, Dict, Generator, cast +from typing import Any, Dict, Type, cast import pytest -from _pytest.fixtures import SubRequest -from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf +from pytorch_lightning import LightningDataModule, Trainer -from torchgeo.datasets import ChesapeakeCVPRDataModule +from torchgeo.datamodules import ( + ChesapeakeCVPRDataModule, + ETCI2021DataModule, + LandCoverAIDataModule, + NAIPChesapeakeDataModule, + OSCDDataModule, + SEN12MSDataModule, +) from torchgeo.trainers import SemanticSegmentationTask -from .test_utils import FakeTrainer, mocked_log - class TestSemanticSegmentationTask: - @pytest.fixture(scope="class") - def datamodule(self) -> ChesapeakeCVPRDataModule: - dm = ChesapeakeCVPRDataModule( - os.path.join("tests", "data", "chesapeake", "cvpr"), - ["de-test"], - ["de-test"], - ["de-test"], - patch_size=32, - patches_per_tile=2, - batch_size=2, - num_workers=0, - class_set=7, - ) - dm.prepare_data() - dm.setup() - return dm - - @pytest.fixture( - params=zip(["unet", "deeplabv3+", "fcn"], ["ce", "jaccard", "focal"]) + @pytest.mark.parametrize( + "name,classname", + [ + ("chesapeake_cvpr_5", ChesapeakeCVPRDataModule), + ("etci2021", ETCI2021DataModule), + ("landcoverai", LandCoverAIDataModule), + ("naipchesapeake", NAIPChesapeakeDataModule), + ("oscd_all", OSCDDataModule), + ("oscd_rgb", OSCDDataModule), + ("sen12ms_all", SEN12MSDataModule), + ("sen12ms_s1", SEN12MSDataModule), + ("sen12ms_s2_all", SEN12MSDataModule), + ("sen12ms_s2_reduced", SEN12MSDataModule), + ], ) - def config(self, request: SubRequest) -> Dict[str, Any]: - task_conf = OmegaConf.load( - os.path.join("conf", "task_defaults", "chesapeake_cvpr.yaml") - ) - task_args = OmegaConf.to_object(task_conf.experiment.module) - task_args = cast(Dict[str, Any], task_args) - segmentation_model, loss = request.param - task_args["segmentation_model"] = segmentation_model - task_args["loss"] = loss - return task_args + def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: + conf = OmegaConf.load(os.path.join("conf", "task_defaults", name + ".yaml")) + conf_dict = OmegaConf.to_object(conf.experiment) + conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) - @pytest.fixture - def task( - self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] - ) -> SemanticSegmentationTask: - task = SemanticSegmentationTask(**config) - trainer = FakeTrainer() - monkeypatch.setattr(task, "trainer", trainer) # type: ignore[attr-defined] - monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] - return task + # Instantiate datamodule + datamodule_kwargs = conf_dict["datamodule"] + datamodule = classname(**datamodule_kwargs) - def test_configure_optimizers(self, task: SemanticSegmentationTask) -> None: - out = task.configure_optimizers() - assert "optimizer" in out - assert "lr_scheduler" in out + # Instantiate model + model_kwargs = conf_dict["module"] + model = SemanticSegmentationTask(**model_kwargs) - def test_training( - self, datamodule: ChesapeakeCVPRDataModule, task: SemanticSegmentationTask - ) -> None: - batch = next(iter(datamodule.train_dataloader())) - task.training_step(batch, 0) - task.training_epoch_end(0) + # Instantiate trainer + trainer = Trainer(fast_dev_run=True, log_every_n_steps=1) + trainer.fit(model=model, datamodule=datamodule) + trainer.test(model=model, datamodule=datamodule) - def test_validation( - self, datamodule: ChesapeakeCVPRDataModule, task: SemanticSegmentationTask - ) -> None: - batch = next(iter(datamodule.val_dataloader())) - task.validation_step(batch, 0) - task.validation_epoch_end(0) - - def test_test( - self, datamodule: ChesapeakeCVPRDataModule, task: SemanticSegmentationTask - ) -> None: - batch = next(iter(datamodule.test_dataloader())) - task.test_step(batch, 0) - task.test_epoch_end(0) + @pytest.fixture + def model_kwargs(self) -> Dict[Any, Any]: + return { + "segmentation_model": "unet", + "encoder_name": "resnet18", + "encoder_weights": None, + "in_channels": 1, + "num_classes": 1, + "loss": "ce", + "ignore_zeros": True, + } - def test_invalid_model(self, config: Dict[str, Any]) -> None: - config["segmentation_model"] = "invalid_model" - error_message = "Model type 'invalid_model' is not valid." - with pytest.raises(ValueError, match=error_message): - SemanticSegmentationTask(**config) + def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None: + model_kwargs["segmentation_model"] = "invalid_model" + match = "Model type 'invalid_model' is not valid." + with pytest.raises(ValueError, match=match): + SemanticSegmentationTask(**model_kwargs) - def test_invalid_loss(self, config: Dict[str, Any]) -> None: - config["loss"] = "invalid_loss" - error_message = "Loss type 'invalid_loss' is not valid." - with pytest.raises(ValueError, match=error_message): - SemanticSegmentationTask(**config) + def test_invalid_loss(self, model_kwargs: Dict[Any, Any]) -> None: + model_kwargs["loss"] = "invalid_loss" + match = "Loss type 'invalid_loss' is not valid." + with pytest.raises(ValueError, match=match): + SemanticSegmentationTask(**model_kwargs) diff --git a/torchgeo/__init__.py b/torchgeo/__init__.py index cc1b2abb965..0f79c7e5e59 100644 --- a/torchgeo/__init__.py +++ b/torchgeo/__init__.py @@ -9,55 +9,6 @@ The :mod:`torchgeo` package consists of popular datasets, model architectures, and common image transformations for geospatial data. """ -from typing import Dict, Tuple, Type - -import pytorch_lightning as pl - -from .datasets import ( - BigEarthNetDataModule, - ChesapeakeCVPRDataModule, - COWCCountingDataModule, - CycloneDataModule, - ETCI2021DataModule, - EuroSATDataModule, - LandCoverAIDataModule, - NAIPChesapeakeDataModule, - OSCDDataModule, - RESISC45DataModule, - SEN12MSDataModule, - So2SatDataModule, - UCMercedDataModule, -) -from .trainers import ( - BYOLTask, - ClassificationTask, - MultiLabelClassificationTask, - RegressionTask, - SemanticSegmentationTask, -) -from .trainers.chesapeake import ChesapeakeCVPRSegmentationTask -from .trainers.landcoverai import LandCoverAISegmentationTask -from .trainers.naipchesapeake import NAIPChesapeakeSegmentationTask -from .trainers.resisc45 import RESISC45ClassificationTask __author__ = "Adam J. Stewart" __version__ = "0.2.0.dev0" - -_TASK_TO_MODULES_MAPPING: Dict[ - str, Tuple[Type[pl.LightningModule], Type[pl.LightningDataModule]] -] = { - "bigearthnet": (MultiLabelClassificationTask, BigEarthNetDataModule), - "byol": (BYOLTask, ChesapeakeCVPRDataModule), - "chesapeake_cvpr": (ChesapeakeCVPRSegmentationTask, ChesapeakeCVPRDataModule), - "cowc_counting": (RegressionTask, COWCCountingDataModule), - "cyclone": (RegressionTask, CycloneDataModule), - "eurosat": (ClassificationTask, EuroSATDataModule), - "etci2021": (SemanticSegmentationTask, ETCI2021DataModule), - "landcoverai": (LandCoverAISegmentationTask, LandCoverAIDataModule), - "naipchesapeake": (NAIPChesapeakeSegmentationTask, NAIPChesapeakeDataModule), - "oscd": (SemanticSegmentationTask, OSCDDataModule), - "resisc45": (RESISC45ClassificationTask, RESISC45DataModule), - "sen12ms": (SemanticSegmentationTask, SEN12MSDataModule), - "so2sat": (ClassificationTask, So2SatDataModule), - "ucmerced": (ClassificationTask, UCMercedDataModule), -} diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py new file mode 100644 index 00000000000..e09fe0ab378 --- /dev/null +++ b/torchgeo/datamodules/__init__.py @@ -0,0 +1,52 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""TorchGeo datamodules.""" + +from .bigearthnet import BigEarthNetDataModule +from .chesapeake import ChesapeakeCVPRDataModule +from .cowc import COWCCountingDataModule +from .cyclone import CycloneDataModule +from .etci2021 import ETCI2021DataModule +from .eurosat import EuroSATDataModule +from .fair1m import FAIR1MDataModule +from .landcoverai import LandCoverAIDataModule +from .loveda import LoveDADataModule +from .naip import NAIPChesapeakeDataModule +from .nasa_marine_debris import NASAMarineDebrisDataModule +from .oscd import OSCDDataModule +from .potsdam import Potsdam2DDataModule +from .resisc45 import RESISC45DataModule +from .sen12ms import SEN12MSDataModule +from .so2sat import So2SatDataModule +from .ucmerced import UCMercedDataModule +from .vaihingen import Vaihingen2DDataModule +from .xview import XView2DataModule + +__all__ = ( + # GeoDataset + "ChesapeakeCVPRDataModule", + "NAIPChesapeakeDataModule", + # VisionDataset + "BigEarthNetDataModule", + "COWCCountingDataModule", + "ETCI2021DataModule", + "EuroSATDataModule", + "FAIR1MDataModule", + "LandCoverAIDataModule", + "LoveDADataModule", + "NASAMarineDebrisDataModule", + "OSCDDataModule", + "Potsdam2DDataModule", + "RESISC45DataModule", + "SEN12MSDataModule", + "So2SatDataModule", + "CycloneDataModule", + "UCMercedDataModule", + "Vaihingen2DDataModule", + "XView2DataModule", +) + +# https://stackoverflow.com/questions/40018681 +for module in __all__: + globals()[module].__module__ = "torchgeo.datamodules" diff --git a/torchgeo/datamodules/bigearthnet.py b/torchgeo/datamodules/bigearthnet.py new file mode 100644 index 00000000000..11c2e4ed9ab --- /dev/null +++ b/torchgeo/datamodules/bigearthnet.py @@ -0,0 +1,178 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""BigEarthNet datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +import torch +from torch.utils.data import DataLoader +from torchvision.transforms import Compose + +from ..datasets import BigEarthNet + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +class BigEarthNetDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the BigEarthNet dataset. + + Uses the train/val/test splits from the dataset. + """ + + # (VV, VH, B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12) + # min/max band statistics computed on 100k random samples + band_mins_raw = torch.tensor( # type: ignore[attr-defined] + [-70.0, -72.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0] + ) + band_maxs_raw = torch.tensor( # type: ignore[attr-defined] + [ + 31.0, + 35.0, + 18556.0, + 20528.0, + 18976.0, + 17874.0, + 16611.0, + 16512.0, + 16394.0, + 16672.0, + 16141.0, + 16097.0, + 15336.0, + 15203.0, + ] + ) + + # min/max band statistics computed by percentile clipping the + # above to samples to [2, 98] + band_mins = torch.tensor( # type: ignore[attr-defined] + [-48.0, -42.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + ) + band_maxs = torch.tensor( # type: ignore[attr-defined] + [ + 6.0, + 16.0, + 9859.0, + 12872.0, + 13163.0, + 14445.0, + 12477.0, + 12563.0, + 12289.0, + 15596.0, + 12183.0, + 9458.0, + 5897.0, + 5544.0, + ] + ) + + def __init__( + self, + root_dir: str, + bands: str = "all", + num_classes: int = 19, + batch_size: int = 64, + num_workers: int = 0, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for BigEarthNet based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the BigEarthNet Dataset classes + bands: load Sentinel-1 bands, Sentinel-2, or both. one of {s1, s2, all} + num_classes: number of classes to load in target. one of {19, 43} + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.bands = bands + self.num_classes = num_classes + self.batch_size = batch_size + self.num_workers = num_workers + + if bands == "all": + self.mins = self.band_mins[:, None, None] + self.maxs = self.band_maxs[:, None, None] + elif bands == "s1": + self.mins = self.band_mins[:2, None, None] + self.maxs = self.band_maxs[:2, None, None] + else: + self.mins = self.band_mins[2:, None, None] + self.maxs = self.band_maxs[2:, None, None] + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset.""" + sample["image"] = sample["image"].float() + sample["image"] = (sample["image"] - self.mins) / (self.maxs - self.mins) + sample["image"] = torch.clip( # type: ignore[attr-defined] + sample["image"], min=0.0, max=1.0 + ) + return sample + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + BigEarthNet(self.root_dir, split="train", bands=self.bands, checksum=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + """ + transforms = Compose([self.preprocess]) + self.train_dataset = BigEarthNet( + self.root_dir, + split="train", + bands=self.bands, + num_classes=self.num_classes, + transforms=transforms, + ) + self.val_dataset = BigEarthNet( + self.root_dir, + split="val", + bands=self.bands, + num_classes=self.num_classes, + transforms=transforms, + ) + self.test_dataset = BigEarthNet( + self.root_dir, + split="test", + bands=self.bands, + num_classes=self.num_classes, + transforms=transforms, + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training.""" + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation.""" + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing.""" + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py new file mode 100644 index 00000000000..2f575f048ba --- /dev/null +++ b/torchgeo/datamodules/chesapeake.py @@ -0,0 +1,341 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Chesapeake Bay High-Resolution Land Cover Project datamodule.""" + +from typing import Any, Callable, Dict, List, Optional + +import torch +import torch.nn.functional as F +from pytorch_lightning.core.datamodule import LightningDataModule +from torch import Tensor +from torch.utils.data import DataLoader +from torchvision.transforms import Compose + +from ..datasets import ChesapeakeCVPR, stack_samples +from ..samplers.batch import RandomBatchGeoSampler +from ..samplers.single import GridGeoSampler + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +class ChesapeakeCVPRDataModule(LightningDataModule): + """LightningDataModule implementation for the Chesapeake CVPR Land Cover dataset. + + Uses the random splits defined per state to partition tiles into train, val, + and test sets. + """ + + def __init__( + self, + root_dir: str, + train_splits: List[str], + val_splits: List[str], + test_splits: List[str], + patches_per_tile: int = 200, + patch_size: int = 256, + batch_size: int = 64, + num_workers: int = 0, + class_set: int = 7, + use_prior_labels: bool = False, + prior_smoothing_constant: float = 1e-4, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for Chesapeake CVPR based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the ChesapeakeCVPR Dataset + classes + train_splits: The splits used to train the model, e.g. ["ny-train"] + val_splits: The splits used to validate the model, e.g. ["ny-val"] + test_splits: The splits used to test the model, e.g. ["ny-test"] + patches_per_tile: The number of patches per tile to sample + patch_size: The size of each patch in pixels (test patches will be 1.5 times + this size) + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + class_set: The high-resolution land cover class set to use - 5 or 7 + use_prior_labels: Flag for using a prior over high-resolution classes + instead of the high-resolution labels themselves + prior_smoothing_constant: additive smoothing to add when using prior labels + + Raises: + ValueError: if ``use_prior_labels`` is used with ``class_set==7`` + """ + super().__init__() # type: ignore[no-untyped-call] + for state in train_splits + val_splits + test_splits: + assert state in ChesapeakeCVPR.splits + assert class_set in [5, 7] + if use_prior_labels and class_set != 5: + raise ValueError( + "The pre-generated prior labels are only valid for the 5" + + " class set of labels" + ) + + self.root_dir = root_dir + self.train_splits = train_splits + self.val_splits = val_splits + self.test_splits = test_splits + self.patches_per_tile = patches_per_tile + self.patch_size = patch_size + # This is a rough estimate of how large of a patch we will need to sample in + # EPSG:3857 in order to guarantee a large enough patch in the local CRS. + self.original_patch_size = int(patch_size * 2.0) + self.batch_size = batch_size + self.num_workers = num_workers + self.class_set = class_set + self.use_prior_labels = use_prior_labels + self.prior_smoothing_constant = prior_smoothing_constant + + if self.use_prior_labels: + self.layers = [ + "naip-new", + "prior_from_cooccurrences_101_31_no_osm_no_buildings", + ] + else: + self.layers = ["naip-new", "lc"] + + def pad_to( + self, size: int = 512, image_value: int = 0, mask_value: int = 0 + ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]: + """Returns a function to perform a padding transform on a single sample. + + Args: + size: output image size + image_value: value to pad image with + mask_value: value to pad mask with + + Returns: + function to perform padding + """ + + def pad_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]: + _, height, width = sample["image"].shape + assert height <= size and width <= size + + height_pad = size - height + width_pad = size - width + + # See https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + # for a description of the format of the padding tuple + sample["image"] = F.pad( + sample["image"], + (0, width_pad, 0, height_pad), + mode="constant", + value=image_value, + ) + sample["mask"] = F.pad( + sample["mask"], + (0, width_pad, 0, height_pad), + mode="constant", + value=mask_value, + ) + return sample + + return pad_inner + + def center_crop( + self, size: int = 512 + ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]: + """Returns a function to perform a center crop transform on a single sample. + + Args: + size: output image size + + Returns: + function to perform center crop + """ + + def center_crop_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]: + _, height, width = sample["image"].shape + + y1 = (height - size) // 2 + x1 = (width - size) // 2 + sample["image"] = sample["image"][:, y1 : y1 + size, x1 : x1 + size] + sample["mask"] = sample["mask"][:, y1 : y1 + size, x1 : x1 + size] + + return sample + + return center_crop_inner + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Preprocesses a single sample. + + Args: + sample: sample dictionary containing image and mask + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"] / 255.0 + sample["mask"] = sample["mask"].squeeze() + + if self.use_prior_labels: + sample["mask"] = F.normalize(sample["mask"].float(), p=1, dim=0) + sample["mask"] = F.normalize( + sample["mask"] + self.prior_smoothing_constant, p=1, dim=0 + ) + else: + if self.class_set == 5: + sample["mask"][sample["mask"] == 5] = 4 + sample["mask"][sample["mask"] == 6] = 4 + sample["mask"] = sample["mask"].long() + + sample["image"] = sample["image"].float() + + del sample["bbox"] + + return sample + + def nodata_check( + self, size: int = 512 + ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]: + """Returns a function to check for nodata or mis-sized input. + + Args: + size: output image size + + Returns: + function to check for nodata values + """ + + def nodata_check_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]: + num_channels, height, width = sample["image"].shape + + if height < size or width < size: + sample["image"] = torch.zeros( # type: ignore[attr-defined] + (num_channels, size, size) + ) + sample["mask"] = torch.zeros((size, size)) # type: ignore[attr-defined] + + return sample + + return nodata_check_inner + + def prepare_data(self) -> None: + """Confirms that the dataset is downloaded on the local node. + + This method is called once per node, while :func:`setup` is called once per GPU. + """ + ChesapeakeCVPR( + self.root_dir, + splits=self.train_splits, + layers=self.layers, + transforms=None, + download=False, + checksum=False, + ) + + def setup(self, stage: Optional[str] = None) -> None: + """Create the train/val/test splits based on the original Dataset objects. + + The splits should be done here vs. in :func:`__init__` per the docs: + https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup. + + Args: + stage: stage to set up + """ + train_transforms = Compose( + [ + self.center_crop(self.patch_size), + self.nodata_check(self.patch_size), + self.preprocess, + ] + ) + val_transforms = Compose( + [ + self.center_crop(self.patch_size), + self.nodata_check(self.patch_size), + self.preprocess, + ] + ) + test_transforms = Compose( + [ + self.pad_to(self.original_patch_size, image_value=0, mask_value=0), + self.preprocess, + ] + ) + + self.train_dataset = ChesapeakeCVPR( + self.root_dir, + splits=self.train_splits, + layers=self.layers, + transforms=train_transforms, + download=False, + checksum=False, + ) + self.val_dataset = ChesapeakeCVPR( + self.root_dir, + splits=self.val_splits, + layers=self.layers, + transforms=val_transforms, + download=False, + checksum=False, + ) + self.test_dataset = ChesapeakeCVPR( + self.root_dir, + splits=self.test_splits, + layers=self.layers, + transforms=test_transforms, + download=False, + checksum=False, + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + sampler = RandomBatchGeoSampler( + self.train_dataset, + size=self.original_patch_size, + batch_size=self.batch_size, + length=self.patches_per_tile * len(self.train_dataset), + ) + return DataLoader( + self.train_dataset, + batch_sampler=sampler, + num_workers=self.num_workers, + collate_fn=stack_samples, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + sampler = GridGeoSampler( + self.val_dataset, + size=self.original_patch_size, + stride=self.original_patch_size, + ) + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + sampler=sampler, + num_workers=self.num_workers, + collate_fn=stack_samples, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + sampler = GridGeoSampler( + self.test_dataset, + size=self.original_patch_size, + stride=self.original_patch_size, + ) + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + sampler=sampler, + num_workers=self.num_workers, + collate_fn=stack_samples, + ) diff --git a/torchgeo/datamodules/cowc.py b/torchgeo/datamodules/cowc.py new file mode 100644 index 00000000000..4d6e4a7cdb8 --- /dev/null +++ b/torchgeo/datamodules/cowc.py @@ -0,0 +1,123 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""COWC datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +from torch import Generator # type: ignore[attr-defined] +from torch.utils.data import DataLoader, random_split + +from ..datasets import COWCCounting + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +class COWCCountingDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the COWC Counting dataset.""" + + def __init__( + self, + root_dir: str, + seed: int, + batch_size: int = 64, + num_workers: int = 0, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for COWC Counting based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the COWCCounting Dataset class + seed: The seed value to use when doing the dataset random_split + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.seed = seed + self.batch_size = batch_size + self.num_workers = num_workers + + def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: dictionary containing image and target + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"] / 255.0 # scale to [0, 1] + sample["label"] = sample["label"].float() + return sample + + def prepare_data(self) -> None: + """Initialize the main ``Dataset`` objects for use in :func:`setup`. + + This includes optionally downloading the dataset. This is done once per node, + while :func:`setup` is done once per GPU. + """ + COWCCounting(self.root_dir, download=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Create the train/val/test splits based on the original Dataset objects. + + The splits should be done here vs. in :func:`__init__` per the docs: + https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup. + + Args: + stage: stage to set up + """ + train_val_dataset = COWCCounting( + self.root_dir, split="train", transforms=self.custom_transform + ) + self.test_dataset = COWCCounting( + self.root_dir, split="test", transforms=self.custom_transform + ) + self.train_dataset, self.val_dataset = random_split( + train_val_dataset, + [len(train_val_dataset) - len(self.test_dataset), len(self.test_dataset)], + generator=Generator().manual_seed(self.seed), + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py new file mode 100644 index 00000000000..929628e7c37 --- /dev/null +++ b/torchgeo/datamodules/cyclone.py @@ -0,0 +1,171 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Tropical Cyclone Wind Estimation Competition datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +import torch +from sklearn.model_selection import GroupShuffleSplit +from torch.utils.data import DataLoader, Subset + +from ..datasets import TropicalCycloneWindEstimation + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +class CycloneDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the NASA Cyclone dataset. + + Implements 80/20 train/val splits based on hurricane storm ids. + See :func:`setup` for more details. + """ + + def __init__( + self, + root_dir: str, + seed: int, + batch_size: int = 64, + num_workers: int = 0, + api_key: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for NASA Cyclone based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the + TropicalCycloneWindEstimation Datasets classes + seed: The seed value to use when doing the sklearn based GroupShuffleSplit + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + api_key: The RadiantEarth MLHub API key to use if the dataset needs to be + downloaded + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.seed = seed + self.batch_size = batch_size + self.num_workers = num_workers + self.api_key = api_key + + def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: dictionary containing image and target + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"] / 255.0 # scale to [0,1] + sample["image"] = ( + sample["image"].unsqueeze(0).repeat(3, 1, 1) + ) # convert to 3 channel + sample["label"] = torch.as_tensor( # type: ignore[attr-defined] + sample["label"] + ).float() + + return sample + + def prepare_data(self) -> None: + """Initialize the main ``Dataset`` objects for use in :func:`setup`. + + This includes optionally downloading the dataset. This is done once per node, + while :func:`setup` is done once per GPU. + """ + TropicalCycloneWindEstimation( + self.root_dir, + split="train", + transforms=self.custom_transform, + download=self.api_key is not None, + api_key=self.api_key, + ) + + def setup(self, stage: Optional[str] = None) -> None: + """Create the train/val/test splits based on the original Dataset objects. + + The splits should be done here vs. in :func:`__init__` per the docs: + https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup. + + We split samples between train/val by the ``storm_id`` property. I.e. all + samples with the same ``storm_id`` value will be either in the train or the val + split. This is important to test one type of generalizability -- given a new + storm, can we predict its windspeed. The test set, however, contains *some* + storms from the training set (specifically, the latter parts of the storms) as + well as some novel storms. + + Args: + stage: stage to set up + """ + self.all_train_dataset = TropicalCycloneWindEstimation( + self.root_dir, + split="train", + transforms=self.custom_transform, + download=False, + ) + + self.all_test_dataset = TropicalCycloneWindEstimation( + self.root_dir, + split="test", + transforms=self.custom_transform, + download=False, + ) + + storm_ids = [] + for item in self.all_train_dataset.collection: + storm_id = item["href"].split("/")[0].split("_")[-2] + storm_ids.append(storm_id) + + train_indices, val_indices = next( + GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=self.seed).split( + storm_ids, groups=storm_ids + ) + ) + + self.train_dataset = Subset(self.all_train_dataset, train_indices) + self.val_dataset = Subset(self.all_train_dataset, val_indices) + self.test_dataset = Subset( + self.all_test_dataset, range(len(self.all_test_dataset)) + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datamodules/etci2021.py b/torchgeo/datamodules/etci2021.py new file mode 100644 index 00000000000..5db89a07379 --- /dev/null +++ b/torchgeo/datamodules/etci2021.py @@ -0,0 +1,151 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""ETCI 2021 datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +import torch +from torch import Generator # type: ignore[attr-defined] +from torch.utils.data import DataLoader, random_split +from torchvision.transforms import Normalize + +from ..datasets import ETCI2021 + + +class ETCI2021DataModule(pl.LightningDataModule): + """LightningDataModule implementation for the ETCI2021 dataset. + + Splits the existing train split from the dataset into train/val with 80/20 + proportions, then uses the existing val dataset as the test data. + + .. versionadded:: 0.2 + """ + + band_means = torch.tensor( # type: ignore[attr-defined] + [0.52253931, 0.52253931, 0.52253931, 0.61221701, 0.61221701, 0.61221701, 0] + ) + + band_stds = torch.tensor( # type: ignore[attr-defined] + [0.35221376, 0.35221376, 0.35221376, 0.37364622, 0.37364622, 0.37364622, 1] + ) + + def __init__( + self, + root_dir: str, + seed: int = 0, + batch_size: int = 64, + num_workers: int = 0, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for ETCI2021 based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the ETCI2021 Dataset classes + seed: The seed value to use when doing the dataset random_split + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.seed = seed + self.batch_size = batch_size + self.num_workers = num_workers + + self.norm = Normalize(self.band_means, self.band_stds) + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Notably, moves the given water mask to act as an input layer. + + Args: + sample: input image dictionary + + Returns: + preprocessed sample + """ + image = sample["image"] + water_mask = sample["mask"][0].unsqueeze(0) + flood_mask = sample["mask"][1] + flood_mask = (flood_mask > 0).long() + + sample["image"] = torch.cat( # type: ignore[attr-defined] + [image, water_mask], dim=0 + ).float() + sample["image"] /= 255.0 + sample["image"] = self.norm(sample["image"]) + sample["mask"] = flood_mask + return sample + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + ETCI2021(self.root_dir, checksum=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + train_val_dataset = ETCI2021( + self.root_dir, split="train", transforms=self.preprocess + ) + self.test_dataset = ETCI2021( + self.root_dir, split="val", transforms=self.preprocess + ) + + size_train_val = len(train_val_dataset) + size_train = int(0.8 * size_train_val) + size_val = size_train_val - size_train + + self.train_dataset, self.val_dataset = random_split( + train_val_dataset, + [size_train, size_val], + generator=Generator().manual_seed(self.seed), + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datamodules/eurosat.py b/torchgeo/datamodules/eurosat.py new file mode 100644 index 00000000000..72708e07019 --- /dev/null +++ b/torchgeo/datamodules/eurosat.py @@ -0,0 +1,148 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""EuroSAT datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +import torch +from torch.utils.data import DataLoader +from torchvision.transforms import Compose, Normalize + +from ..datasets import EuroSAT + + +class EuroSATDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the EuroSAT dataset. + + Uses the train/val/test splits from the dataset. + + .. versionadded:: 0.2 + """ + + band_means = torch.tensor( # type: ignore[attr-defined] + [ + 1354.40546513, + 1118.24399958, + 1042.92983953, + 947.62620298, + 1199.47283961, + 1999.79090914, + 2369.22292565, + 2296.82608323, + 732.08340178, + 12.11327804, + 1819.01027855, + 1118.92391149, + 2594.14080798, + ] + ) + + band_stds = torch.tensor( # type: ignore[attr-defined] + [ + 245.71762908, + 333.00778264, + 395.09249139, + 593.75055589, + 566.4170017, + 861.18399006, + 1086.63139075, + 1117.98170791, + 404.91978886, + 4.77584468, + 1002.58768311, + 761.30323499, + 1231.58581042, + ] + ) + + def __init__( + self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any + ) -> None: + """Initialize a LightningDataModule for EuroSAT based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the EuroSAT Dataset classes + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.batch_size = batch_size + self.num_workers = num_workers + + self.norm = Normalize(self.band_means, self.band_stds) + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: input image dictionary + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"].float() + sample["image"] = self.norm(sample["image"]) + return sample + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + EuroSAT(self.root_dir) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + transforms = Compose([self.preprocess]) + + self.train_dataset = EuroSAT(self.root_dir, "train", transforms=transforms) + self.val_dataset = EuroSAT(self.root_dir, "val", transforms=transforms) + self.test_dataset = EuroSAT(self.root_dir, "test", transforms=transforms) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datamodules/fair1m.py b/torchgeo/datamodules/fair1m.py new file mode 100644 index 00000000000..15a8cbfca52 --- /dev/null +++ b/torchgeo/datamodules/fair1m.py @@ -0,0 +1,132 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""FAIR1M datamodule.""" + +from typing import Any, Dict, List, Optional + +import pytorch_lightning as pl +import torch +from torch import Tensor +from torch.utils.data import DataLoader +from torchvision.transforms import Compose + +from ..datasets import FAIR1M +from .utils import dataset_split + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]: + """Custom object detection collate fn to handle variable number of boxes. + + Args: + batch: list of sample dicts return by dataset + Returns: + batch dict output + """ + output: Dict[str, Any] = {} + output["image"] = torch.stack([sample["image"] for sample in batch]) + output["boxes"] = [sample["boxes"] for sample in batch] + return output + + +class FAIR1MDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the FAIR1M dataset.""" + + def __init__( + self, + root_dir: str, + batch_size: int = 64, + num_workers: int = 0, + val_split_pct: float = 0.2, + test_split_pct: float = 0.2, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for FAIR1M based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the FAIR1M Dataset classes + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + val_split_pct: What percentage of the dataset to use as a validation set + test_split_pct: What percentage of the dataset to use as a test set + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.batch_size = batch_size + self.num_workers = num_workers + self.val_split_pct = val_split_pct + self.test_split_pct = test_split_pct + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: input image dictionary + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"].float() + sample["image"] /= 255.0 + return sample + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + transforms = Compose([self.preprocess]) + + dataset = FAIR1M(self.root_dir, transforms=transforms) + self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( + dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + collate_fn=collate_fn, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + collate_fn=collate_fn, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + collate_fn=collate_fn, + ) diff --git a/torchgeo/datamodules/landcoverai.py b/torchgeo/datamodules/landcoverai.py new file mode 100644 index 00000000000..b0f23182a27 --- /dev/null +++ b/torchgeo/datamodules/landcoverai.py @@ -0,0 +1,122 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""LandCover.ai datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +from torch.utils.data import DataLoader + +from ..datasets import LandCoverAI + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +class LandCoverAIDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the LandCover.ai dataset. + + Uses the train/val/test splits from the dataset. + """ + + def __init__( + self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any + ) -> None: + """Initialize a LightningDataModule for LandCover.ai based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the Landcover.AI Dataset classes + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.batch_size = batch_size + self.num_workers = num_workers + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: dictionary containing image and mask + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"] / 255.0 + + sample["image"] = sample["image"].float() + sample["mask"] = sample["mask"].long() + 1 + + return sample + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + _ = LandCoverAI(self.root_dir, download=False, checksum=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + train_transforms = self.preprocess + val_test_transforms = self.preprocess + + self.train_dataset = LandCoverAI( + self.root_dir, split="train", transforms=train_transforms + ) + + self.val_dataset = LandCoverAI( + self.root_dir, split="val", transforms=val_test_transforms + ) + + self.test_dataset = LandCoverAI( + self.root_dir, split="test", transforms=val_test_transforms + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datamodules/loveda.py b/torchgeo/datamodules/loveda.py new file mode 100644 index 00000000000..4aeae5323b6 --- /dev/null +++ b/torchgeo/datamodules/loveda.py @@ -0,0 +1,129 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""LoveDA datamodule.""" + +from typing import Any, Dict, List, Optional + +import pytorch_lightning as pl +from torch.utils.data import DataLoader + +from ..datasets import LoveDA + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +class LoveDADataModule(pl.LightningDataModule): + """LightningDataModule implementation for the LoveDA dataset. + + Uses the train/val/test splits from the dataset. + """ + + def __init__( + self, + root_dir: str, + scene: List[str], + batch_size: int = 32, + num_workers: int = 0, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for LoveDA based DataLoaders. + + Args: + root_dir: The ``root`` argument to pass to LoveDA Dataset classes + scene: specify whether to load only 'urban', only 'rural' or both + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.scene = scene + self.batch_size = batch_size + self.num_workers = num_workers + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: dictionary containing image and mask + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"] / 255.0 + + return sample + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + _ = LoveDA(self.root_dir, scene=self.scene, download=False, checksum=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + train_transforms = self.preprocess + val_test_transforms = self.preprocess + + self.train_dataset = LoveDA( + self.root_dir, split="train", scene=self.scene, transforms=train_transforms + ) + + self.val_dataset = LoveDA( + self.root_dir, split="val", scene=self.scene, transforms=val_test_transforms + ) + + self.test_dataset = LoveDA( + self.root_dir, + split="test", + scene=self.scene, + transforms=val_test_transforms, + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py new file mode 100644 index 00000000000..928674dc5bd --- /dev/null +++ b/torchgeo/datamodules/naip.py @@ -0,0 +1,167 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""National Agriculture Imagery Program (NAIP) datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +from torch.utils.data import DataLoader + +from ..datasets import NAIP, BoundingBox, Chesapeake13, stack_samples +from ..samplers.batch import RandomBatchGeoSampler +from ..samplers.single import GridGeoSampler + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +class NAIPChesapeakeDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the NAIP and Chesapeake datasets. + + Uses the train/val/test splits from the dataset. + """ + + # TODO: tune these hyperparams + length = 1000 + stride = 128 + + def __init__( + self, + naip_root_dir: str, + chesapeake_root_dir: str, + batch_size: int = 64, + num_workers: int = 0, + patch_size: int = 256, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for NAIP and Chesapeake based DataLoaders. + + Args: + naip_root_dir: directory containing NAIP data + chesapeake_root_dir: directory containing Chesapeake data + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + patch_size: size of patches to sample + """ + super().__init__() # type: ignore[no-untyped-call] + self.naip_root_dir = naip_root_dir + self.chesapeake_root_dir = chesapeake_root_dir + self.batch_size = batch_size + self.num_workers = num_workers + self.patch_size = patch_size + + def naip_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the NAIP Dataset. + + Args: + sample: NAIP image dictionary + + Returns: + preprocessed NAIP data + """ + sample["image"] = sample["image"] / 255.0 + sample["image"] = sample["image"].float() + + del sample["bbox"] + + return sample + + def chesapeake_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Chesapeake Dataset. + + Args: + sample: Chesapeake mask dictionary + + Returns: + preprocessed Chesapeake data + """ + sample["mask"] = sample["mask"].long()[0] + + del sample["bbox"] + + return sample + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + Chesapeake13(self.chesapeake_root_dir, download=False, checksum=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: state to set up + """ + # TODO: these transforms will be applied independently, this won't work if we + # add things like random horizontal flip + chesapeake = Chesapeake13( + self.chesapeake_root_dir, transforms=self.chesapeake_transform + ) + naip = NAIP( + self.naip_root_dir, + chesapeake.crs, + chesapeake.res, + transforms=self.naip_transform, + ) + self.dataset = chesapeake & naip + + # TODO: figure out better train/val/test split + roi = self.dataset.bounds + midx = roi.minx + (roi.maxx - roi.minx) / 2 + midy = roi.miny + (roi.maxy - roi.miny) / 2 + train_roi = BoundingBox(roi.minx, midx, roi.miny, roi.maxy, roi.mint, roi.maxt) + val_roi = BoundingBox(midx, roi.maxx, roi.miny, midy, roi.mint, roi.maxt) + test_roi = BoundingBox(roi.minx, roi.maxx, midy, roi.maxy, roi.mint, roi.maxt) + + self.train_sampler = RandomBatchGeoSampler( + naip, self.patch_size, self.batch_size, self.length, train_roi + ) + self.val_sampler = GridGeoSampler(naip, self.patch_size, self.stride, val_roi) + self.test_sampler = GridGeoSampler(naip, self.patch_size, self.stride, test_roi) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.dataset, + batch_sampler=self.train_sampler, + num_workers=self.num_workers, + collate_fn=stack_samples, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.dataset, + batch_size=self.batch_size, + sampler=self.val_sampler, + num_workers=self.num_workers, + collate_fn=stack_samples, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.dataset, + batch_size=self.batch_size, + sampler=self.test_sampler, + num_workers=self.num_workers, + collate_fn=stack_samples, + ) diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py new file mode 100644 index 00000000000..e6337e9fb6a --- /dev/null +++ b/torchgeo/datamodules/nasa_marine_debris.py @@ -0,0 +1,140 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""NASA Marine Debris datamodule.""" + +from typing import Any, Dict, List, Optional + +import pytorch_lightning as pl +import torch +from torch import Tensor +from torch.utils.data import DataLoader +from torchvision.transforms import Compose + +from ..datasets import NASAMarineDebris +from .utils import dataset_split + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]: + """Custom object detection collate fn to handle variable boxes. + + Args: + batch: list of sample dicts return by dataset + + Returns: + batch dict output + """ + output: Dict[str, Any] = {} + output["image"] = torch.stack([sample["image"] for sample in batch]) + output["boxes"] = [sample["boxes"] for sample in batch] + return output + + +class NASAMarineDebrisDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the NASA Marine Debris dataset.""" + + def __init__( + self, + root_dir: str, + batch_size: int = 64, + num_workers: int = 0, + val_split_pct: float = 0.2, + test_split_pct: float = 0.2, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for NASA Marine Debris based DataLoaders. + + Args: + root_dir: The ``root`` argument to pass to the Dataset class + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + val_split_pct: What percentage of the dataset to use as a validation set + test_split_pct: What percentage of the dataset to use as a test set + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.batch_size = batch_size + self.num_workers = num_workers + self.val_split_pct = val_split_pct + self.test_split_pct = test_split_pct + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: input image dictionary + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"].float() + sample["image"] /= 255.0 + return sample + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + NASAMarineDebris(self.root_dir, checksum=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + transforms = Compose([self.preprocess]) + + dataset = NASAMarineDebris(self.root_dir, transforms=transforms) + self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( + dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + collate_fn=collate_fn, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + collate_fn=collate_fn, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + collate_fn=collate_fn, + ) diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py new file mode 100644 index 00000000000..cba602fa263 --- /dev/null +++ b/torchgeo/datamodules/oscd.py @@ -0,0 +1,214 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""OSCD datamodule.""" + +from typing import Any, Dict, List, Optional, Tuple + +import kornia.augmentation as K +import pytorch_lightning as pl +import torch +from einops import repeat +from torch.utils.data import DataLoader, Dataset +from torch.utils.data._utils.collate import default_collate +from torchvision.transforms import Compose, Normalize + +from ..datasets import OSCD +from .utils import dataset_split + + +class OSCDDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the OSCD dataset. + + Uses the train/test splits from the dataset and further splits + the train split into train/val splits. + + .. versionadded:: 0.2 + """ + + band_means = torch.tensor( # type: ignore[attr-defined] + [ + 1583.0741, + 1374.3202, + 1294.1616, + 1325.6158, + 1478.7408, + 1933.0822, + 2166.0608, + 2076.4868, + 2306.0652, + 690.9814, + 16.2360, + 2080.3347, + 1524.6930, + ] + ) + + band_stds = torch.tensor( # type: ignore[attr-defined] + [ + 52.1937, + 83.4168, + 105.6966, + 151.1401, + 147.4615, + 115.9289, + 123.1974, + 114.6483, + 141.4530, + 73.2758, + 4.8368, + 213.4821, + 179.4793, + ] + ) + + def __init__( + self, + root_dir: str, + bands: str = "all", + train_batch_size: int = 32, + num_workers: int = 0, + val_split_pct: float = 0.2, + patch_size: Tuple[int, int] = (64, 64), + num_patches_per_tile: int = 32, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for OSCD based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the OSCD Dataset classes + bands: "rgb" or "all" + train_batch_size: The batch size used in the train DataLoader + (val_batch_size == test_batch_size == 1) + num_workers: The number of workers to use in all created DataLoaders + val_split_pct: What percentage of the dataset to use as a validation set + patch_size: Size of random patch from image and mask (height, width) + num_patches_per_tile: number of random patches per sample + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.bands = bands + self.train_batch_size = train_batch_size + self.num_workers = num_workers + self.val_split_pct = val_split_pct + self.patch_size = patch_size + self.num_patches_per_tile = num_patches_per_tile + + if bands == "rgb": + self.band_means = self.band_means[[3, 2, 1], None, None] + self.band_stds = self.band_stds[[3, 2, 1], None, None] + else: + self.band_means = self.band_means[:, None, None] + self.band_stds = self.band_stds[:, None, None] + + self.norm = Normalize(self.band_means, self.band_stds) + self.rcrop = K.AugmentationSequential( + K.RandomCrop(patch_size), data_keys=["input", "mask"], same_on_batch=True + ) + self.padto = K.PadTo((1280, 1280)) + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset.""" + sample["image"] = sample["image"].float() + sample["mask"] = sample["mask"] + sample["image"] = self.norm(sample["image"]) + sample["image"] = torch.flatten( # type: ignore[attr-defined] + sample["image"], 0, 1 + ) + return sample + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + OSCD(self.root_dir, split="train", bands=self.bands, checksum=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + """ + + def n_random_crop(sample: Dict[str, Any]) -> Dict[str, Any]: + images, masks = [], [] + for i in range(self.num_patches_per_tile): + mask = repeat(sample["mask"], "h w -> t h w", t=2).float() + image, mask = self.rcrop(sample["image"], mask) + mask = mask.squeeze()[0] + images.append(image.squeeze()) + masks.append(mask.long()) + sample["image"] = torch.stack(images) + sample["mask"] = torch.stack(masks) + return sample + + def pad_to(sample: Dict[str, Any]) -> Dict[str, Any]: + sample["image"] = self.padto(sample["image"])[0] + sample["mask"] = self.padto(sample["mask"].float()).long()[0, 0] + return sample + + train_transforms = Compose([self.preprocess, n_random_crop]) + # for testing and validation we pad all inputs to a fixed size to avoid issues + # with the upsampling paths in encoder-decoder architectures + test_transforms = Compose([self.preprocess, pad_to]) + + train_dataset = OSCD( + self.root_dir, split="train", bands=self.bands, transforms=train_transforms + ) + + self.train_dataset: Dataset[Any] + self.val_dataset: Dataset[Any] + + if self.val_split_pct > 0.0: + val_dataset = OSCD( + self.root_dir, + split="train", + bands=self.bands, + transforms=test_transforms, + ) + self.train_dataset, self.val_dataset, _ = dataset_split( + train_dataset, val_pct=self.val_split_pct, test_pct=0.0 + ) + self.val_dataset.dataset = val_dataset + else: + self.train_dataset = train_dataset + self.val_dataset = train_dataset + + self.test_dataset = OSCD( + self.root_dir, split="test", bands=self.bands, transforms=test_transforms + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training.""" + + def collate_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]: + r_batch: Dict[str, Any] = default_collate( # type: ignore[no-untyped-call] + batch + ) + r_batch["image"] = torch.flatten( # type: ignore[attr-defined] + r_batch["image"], 0, 1 + ) + r_batch["mask"] = torch.flatten( # type: ignore[attr-defined] + r_batch["mask"], 0, 1 + ) + return r_batch + + return DataLoader( + self.train_dataset, + batch_size=self.train_batch_size, + num_workers=self.num_workers, + collate_fn=collate_wrapper, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation.""" + return DataLoader( + self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing.""" + return DataLoader( + self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False + ) diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py new file mode 100644 index 00000000000..776d999cb1e --- /dev/null +++ b/torchgeo/datamodules/potsdam.py @@ -0,0 +1,121 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Potsdam datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +from torch.utils.data import DataLoader, Dataset +from torchvision.transforms import Compose + +from ..datasets import Potsdam2D +from .utils import dataset_split + + +class Potsdam2DDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the Potsdam2D dataset. + + Uses the train/test splits from the dataset. + + .. versionadded:: 0.2 + """ + + def __init__( + self, + root_dir: str, + batch_size: int = 64, + num_workers: int = 0, + val_split_pct: float = 0.2, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for Potsdam2D based DataLoaders. + + Args: + root_dir: The ``root`` argument to pass to the Potsdam2D Dataset classes + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + val_split_pct: What percentage of the dataset to use as a validation set + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.batch_size = batch_size + self.num_workers = num_workers + self.val_split_pct = val_split_pct + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: input image dictionary + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"].float() + sample["image"] /= 255.0 + return sample + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + transforms = Compose([self.preprocess]) + + dataset = Potsdam2D(self.root_dir, "train", transforms=transforms) + + self.train_dataset: Dataset[Any] + self.val_dataset: Dataset[Any] + + if self.val_split_pct > 0.0: + self.train_dataset, self.val_dataset, _ = dataset_split( + dataset, val_pct=self.val_split_pct, test_pct=0.0 + ) + else: + self.train_dataset = dataset + self.val_dataset = dataset + + self.test_dataset = Potsdam2D(self.root_dir, "test", transforms=transforms) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datamodules/resisc45.py b/torchgeo/datamodules/resisc45.py new file mode 100644 index 00000000000..844ee0968a9 --- /dev/null +++ b/torchgeo/datamodules/resisc45.py @@ -0,0 +1,123 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""RESISC45 datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +import torch +from torch.utils.data import DataLoader +from torchvision.transforms import Compose, Normalize + +from ..datasets import RESISC45 + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +class RESISC45DataModule(pl.LightningDataModule): + """LightningDataModule implementation for the RESISC45 dataset. + + Uses the train/val/test splits from the dataset. + """ + + band_means = torch.tensor( # type: ignore[attr-defined] + [0.36801773, 0.38097873, 0.343583] + ) + + band_stds = torch.tensor( # type: ignore[attr-defined] + [0.14540215, 0.13558227, 0.13203649] + ) + + def __init__( + self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any + ) -> None: + """Initialize a LightningDataModule for RESISC45 based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the RESISC45 Dataset classes + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.batch_size = batch_size + self.num_workers = num_workers + + self.norm = Normalize(self.band_means, self.band_stds) + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: input image dictionary + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"].float() + sample["image"] /= 255.0 + sample["image"] = self.norm(sample["image"]) + return sample + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + RESISC45(self.root_dir, checksum=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + transforms = Compose([self.preprocess]) + + self.train_dataset = RESISC45(self.root_dir, "train", transforms=transforms) + self.val_dataset = RESISC45(self.root_dir, "val", transforms=transforms) + self.test_dataset = RESISC45(self.root_dir, "test", transforms=transforms) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datamodules/sen12ms.py b/torchgeo/datamodules/sen12ms.py new file mode 100644 index 00000000000..cfe5900c478 --- /dev/null +++ b/torchgeo/datamodules/sen12ms.py @@ -0,0 +1,202 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""SEN12MS datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +import torch +from sklearn.model_selection import GroupShuffleSplit +from torch.utils.data import DataLoader, Subset + +from ..datasets import SEN12MS + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +class SEN12MSDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the SEN12MS dataset. + + Implements 80/20 geographic train/val splits and uses the test split from the + classification dataset definitions. See :func:`setup` for more details. + + Uses the Simplified IGBP scheme defined in the 2020 Data Fusion Competition. See + https://arxiv.org/abs/2002.08254. + """ + + #: Mapping from the IGBP class definitions to the DFC2020, taken from the dataloader + #: here https://github.com/lukasliebel/dfc2020_baseline. + DFC2020_CLASS_MAPPING = torch.tensor( # type: ignore[attr-defined] + [ + 0, # maps 0s to 0 + 1, # maps 1s to 1 + 1, # maps 2s to 1 + 1, # ... + 1, + 1, + 2, + 2, + 3, + 3, + 4, + 5, + 6, + 7, + 6, + 8, + 9, + 10, + ] + ) + + def __init__( + self, + root_dir: str, + seed: int, + band_set: str = "all", + batch_size: int = 64, + num_workers: int = 0, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for SEN12MS based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the SEN12MS Dataset classes + seed: The seed value to use when doing the sklearn based ShuffleSplit + band_set: The subset of S1/S2 bands to use. Options are: "all", + "s1", "s2-all", and "s2-reduced" where the "s2-reduced" set includes: + B2, B3, B4, B8, B11, and B12. + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + """ + super().__init__() # type: ignore[no-untyped-call] + assert band_set in SEN12MS.BAND_SETS.keys() + + self.root_dir = root_dir + self.seed = seed + self.band_set = band_set + self.band_indices = SEN12MS.BAND_SETS[band_set] + self.batch_size = batch_size + self.num_workers = num_workers + + def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: dictionary containing image and mask + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"].float() + + if self.band_set == "all": + sample["image"][:2] = sample["image"][:2].clamp(-25, 0) / -25 + sample["image"][2:] = sample["image"][2:].clamp(0, 10000) / 10000 + elif self.band_set == "s1": + sample["image"][:2] = sample["image"][:2].clamp(-25, 0) / -25 + else: + sample["image"][:] = sample["image"][:].clamp(0, 10000) / 10000 + + sample["mask"] = sample["mask"][0, :, :].long() + sample["mask"] = torch.take( # type: ignore[attr-defined] + self.DFC2020_CLASS_MAPPING, sample["mask"] + ) + + return sample + + def setup(self, stage: Optional[str] = None) -> None: + """Create the train/val/test splits based on the original Dataset objects. + + The splits should be done here vs. in :func:`__init__` per the docs: + https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup. + + We split samples between train and val geographically with proportions of 80/20. + This mimics the geographic test set split. + + Args: + stage: stage to set up + """ + season_to_int = {"winter": 0, "spring": 1000, "summer": 2000, "fall": 3000} + + self.all_train_dataset = SEN12MS( + self.root_dir, + split="train", + bands=self.band_indices, + transforms=self.custom_transform, + checksum=False, + ) + + self.all_test_dataset = SEN12MS( + self.root_dir, + split="test", + bands=self.band_indices, + transforms=self.custom_transform, + checksum=False, + ) + + # A patch is a filename like: "ROIs{num}_{season}_s2_{scene_id}_p{patch_id}.tif" + # This patch will belong to the scene that is uniquelly identified by its + # (season, scene_id) tuple. Because the largest scene_id is 149, we can simply + # give each season a large number and representing a `unique_scene_id` as + # `season_id + scene_id`. + scenes = [] + for scene_fn in self.all_train_dataset.ids: + parts = scene_fn.split("_") + season_id = season_to_int[parts[1]] + scene_id = int(parts[3]) + scenes.append(season_id + scene_id) + + train_indices, val_indices = next( + GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=self.seed).split( + scenes, groups=scenes + ) + ) + + self.train_dataset = Subset(self.all_train_dataset, train_indices) + self.val_dataset = Subset(self.all_train_dataset, val_indices) + self.test_dataset = Subset( + self.all_test_dataset, range(len(self.all_test_dataset)) + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datamodules/so2sat.py b/torchgeo/datamodules/so2sat.py new file mode 100644 index 00000000000..9f072edbf43 --- /dev/null +++ b/torchgeo/datamodules/so2sat.py @@ -0,0 +1,225 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""So2Sat datamodule.""" + +from typing import Any, Dict, Optional, cast + +import pytorch_lightning as pl +import torch +from torch.utils.data import DataLoader +from torchvision.transforms import Compose + +from ..datasets import So2Sat + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +class So2SatDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the So2Sat dataset. + + Uses the train/val/test splits from the dataset. + """ + + band_means = torch.tensor( # type: ignore[attr-defined] + [ + -3.591224256609313e-05, + -7.658561276843396e-06, + 5.9373857475971184e-05, + 2.5166231537121083e-05, + 0.04420110659759328, + 0.25761027084996196, + 0.0007556743372573258, + 0.0013503466830024448, + 0.12375696117681859, + 0.1092774636368323, + 0.1010855203267882, + 0.1142398616114001, + 0.1592656692023089, + 0.18147236008771792, + 0.1745740312291377, + 0.19501607349635292, + 0.15428468872076637, + 0.10905050699570007, + ] + ).reshape(18, 1, 1) + + band_stds = torch.tensor( # type: ignore[attr-defined] + [ + 0.17555201137417686, + 0.17556463274968204, + 0.45998793417834255, + 0.455988755730148, + 2.8559909213125763, + 8.324800606439833, + 2.4498757382563103, + 1.4647352984509094, + 0.03958795985905458, + 0.047778262752410296, + 0.06636616706371974, + 0.06358874912497474, + 0.07744387147984592, + 0.09101635085921553, + 0.09218466562387101, + 0.10164581233948201, + 0.09991773043519253, + 0.08780632509122865, + ] + ).reshape(18, 1, 1) + + # this reorders the bands to put S2 RGB first, then remainder of S2, then S1 + reindex_to_rgb_first = [ + 10, + 9, + 8, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + # 0, + # 1, + # 2, + # 3, + # 4, + # 5, + # 6, + # 7, + ] + + def __init__( + self, + root_dir: str, + batch_size: int = 64, + num_workers: int = 0, + bands: str = "rgb", + unsupervised_mode: bool = False, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for So2Sat based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the So2Sat Dataset classes + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + bands: Either "rgb" or "s2" + unsupervised_mode: Makes the train dataloader return imagery from the train, + val, and test sets + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.batch_size = batch_size + self.num_workers = num_workers + self.bands = bands + self.unsupervised_mode = unsupervised_mode + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: dictionary containing image + + Returns: + preprocessed sample + """ + # sample["image"] = (sample["image"] - self.band_means) / self.band_stds + sample["image"] = sample["image"].float() + sample["image"] = sample["image"][self.reindex_to_rgb_first, :, :] + + if self.bands == "rgb": + sample["image"] = sample["image"][:3, :, :] + + return sample + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + So2Sat(self.root_dir, checksum=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + train_transforms = Compose([self.preprocess]) + val_test_transforms = self.preprocess + + if not self.unsupervised_mode: + + self.train_dataset = So2Sat( + self.root_dir, split="train", transforms=train_transforms + ) + + self.val_dataset = So2Sat( + self.root_dir, split="validation", transforms=val_test_transforms + ) + + self.test_dataset = So2Sat( + self.root_dir, split="test", transforms=val_test_transforms + ) + + else: + + temp_train = So2Sat( + self.root_dir, split="train", transforms=train_transforms + ) + + self.val_dataset = So2Sat( + self.root_dir, split="validation", transforms=train_transforms + ) + + self.test_dataset = So2Sat( + self.root_dir, split="test", transforms=train_transforms + ) + + self.train_dataset = cast( + So2Sat, temp_train + self.val_dataset + self.test_dataset + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datamodules/ucmerced.py b/torchgeo/datamodules/ucmerced.py new file mode 100644 index 00000000000..69cd9773384 --- /dev/null +++ b/torchgeo/datamodules/ucmerced.py @@ -0,0 +1,125 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""UC Merced datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +import torch +import torchvision +from torch.utils.data import DataLoader +from torchvision.transforms import Compose, Normalize + +from ..datasets import UCMerced + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +class UCMercedDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the UC Merced dataset. + + Uses random train/val/test splits. + """ + + band_means = torch.tensor([0, 0, 0]) # type: ignore[attr-defined] + + band_stds = torch.tensor([1, 1, 1]) # type: ignore[attr-defined] + + def __init__( + self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any + ) -> None: + """Initialize a LightningDataModule for UCMerced based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the UCMerced Dataset classes + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.batch_size = batch_size + self.num_workers = num_workers + + self.norm = Normalize(self.band_means, self.band_stds) + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: dictionary containing image + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"].float() + sample["image"] /= 255.0 + c, h, w = sample["image"].shape + if h != 256 or w != 256: + sample["image"] = torchvision.transforms.functional.resize( + sample["image"], size=(256, 256) + ) + sample["image"] = self.norm(sample["image"]) + return sample + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + UCMerced(self.root_dir, download=False, checksum=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + transforms = Compose([self.preprocess]) + + self.train_dataset = UCMerced(self.root_dir, "train", transforms=transforms) + self.val_dataset = UCMerced(self.root_dir, "val", transforms=transforms) + self.test_dataset = UCMerced(self.root_dir, "test", transforms=transforms) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py new file mode 100644 index 00000000000..ff1f571c2b6 --- /dev/null +++ b/torchgeo/datamodules/utils.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Common datamodule utilities.""" + +from typing import Any, List, Optional + +from torch.utils.data import Dataset, Subset, random_split + + +def dataset_split( + dataset: Dataset[Any], val_pct: float, test_pct: Optional[float] = None +) -> List[Subset[Any]]: + """Split a torch Dataset into train/val/test sets. + + If ``test_pct`` is not set then only train and validation splits are returned. + + Args: + dataset: dataset to be split into train/val or train/val/test subsets + val_pct: percentage of samples to be in validation set + test_pct: (Optional) percentage of samples to be in test set + Returns: + a list of the subset datasets. Either [train, val] or [train, val, test] + """ + if test_pct is None: + val_length = int(len(dataset) * val_pct) + train_length = len(dataset) - val_length + return random_split(dataset, [train_length, val_length]) + else: + val_length = int(len(dataset) * val_pct) + test_length = int(len(dataset) * test_pct) + train_length = len(dataset) - (val_length + test_length) + return random_split(dataset, [train_length, val_length, test_length]) diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py new file mode 100644 index 00000000000..cced8a3ff27 --- /dev/null +++ b/torchgeo/datamodules/vaihingen.py @@ -0,0 +1,121 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Vaihingen datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +from torch.utils.data import DataLoader, Dataset +from torchvision.transforms import Compose + +from ..datasets import Vaihingen2D +from .utils import dataset_split + + +class Vaihingen2DDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the Vaihingen2D dataset. + + Uses the train/test splits from the dataset. + + .. versionadded:: 0.2 + """ + + def __init__( + self, + root_dir: str, + batch_size: int = 64, + num_workers: int = 0, + val_split_pct: float = 0.2, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for Vaihingen2D based DataLoaders. + + Args: + root_dir: The ``root`` argument to pass to the Vaihingen Dataset classes + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + val_split_pct: What percentage of the dataset to use as a validation set + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.batch_size = batch_size + self.num_workers = num_workers + self.val_split_pct = val_split_pct + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: input image dictionary + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"].float() + sample["image"] /= 255.0 + return sample + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + transforms = Compose([self.preprocess]) + + dataset = Vaihingen2D(self.root_dir, "train", transforms=transforms) + + self.train_dataset: Dataset[Any] + self.val_dataset: Dataset[Any] + + if self.val_split_pct > 0.0: + self.train_dataset, self.val_dataset, _ = dataset_split( + dataset, val_pct=self.val_split_pct, test_pct=0.0 + ) + else: + self.train_dataset = dataset + self.val_dataset = dataset + + self.test_dataset = Vaihingen2D(self.root_dir, "test", transforms=transforms) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datamodules/xview.py b/torchgeo/datamodules/xview.py new file mode 100644 index 00000000000..2548c24851e --- /dev/null +++ b/torchgeo/datamodules/xview.py @@ -0,0 +1,121 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""xView2 datamodule.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +from torch.utils.data import DataLoader, Dataset +from torchvision.transforms import Compose + +from ..datasets import XView2 +from .utils import dataset_split + + +class XView2DataModule(pl.LightningDataModule): + """LightningDataModule implementation for the xView2 dataset. + + Uses the train/val/test splits from the dataset. + + .. versionadded:: 0.2 + """ + + def __init__( + self, + root_dir: str, + batch_size: int = 64, + num_workers: int = 0, + val_split_pct: float = 0.2, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for xView2 based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the xView2 Dataset classes + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + val_split_pct: What percentage of the dataset to use as a validation set + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.batch_size = batch_size + self.num_workers = num_workers + self.val_split_pct = val_split_pct + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: input image dictionary + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"].float() + sample["image"] /= 255.0 + return sample + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + transforms = Compose([self.preprocess]) + + dataset = XView2(self.root_dir, "train", transforms=transforms) + + self.train_dataset: Dataset[Any] + self.val_dataset: Dataset[Any] + + if self.val_split_pct > 0.0: + self.train_dataset, self.val_dataset, _ = dataset_split( + dataset, val_pct=self.val_split_pct, test_pct=0.0 + ) + else: + self.train_dataset = dataset + self.val_dataset = dataset + + self.test_dataset = XView2(self.root_dir, "test", transforms=transforms) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 0f5e24b38cf..7e3cf7811c2 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -5,7 +5,7 @@ from .advance import ADVANCE from .benin_cashews import BeninSmallHolderCashews -from .bigearthnet import BigEarthNet, BigEarthNetDataModule +from .bigearthnet import BigEarthNet from .cbf import CanadianBuildingFootprints from .cdl import CDL from .chesapeake import ( @@ -13,7 +13,6 @@ Chesapeake7, Chesapeake13, ChesapeakeCVPR, - ChesapeakeCVPRDataModule, ChesapeakeDC, ChesapeakeDE, ChesapeakeMD, @@ -22,12 +21,12 @@ ChesapeakeVA, ChesapeakeWV, ) -from .cowc import COWC, COWCCounting, COWCCountingDataModule, COWCDetection +from .cowc import COWC, COWCCounting, COWCDetection from .cv4a_kenya_crop_type import CV4AKenyaCropType -from .cyclone import CycloneDataModule, TropicalCycloneWindEstimation -from .etci2021 import ETCI2021, ETCI2021DataModule -from .eurosat import EuroSAT, EuroSATDataModule -from .fair1m import FAIR1M, FAIR1MDataModule +from .cyclone import TropicalCycloneWindEstimation +from .etci2021 import ETCI2021 +from .eurosat import EuroSAT +from .fair1m import FAIR1M from .geo import ( GeoDataset, IntersectionDataset, @@ -39,7 +38,7 @@ ) from .gid15 import GID15 from .idtrees import IDTReeS -from .landcoverai import LandCoverAI, LandCoverAIDataModule +from .landcoverai import LandCoverAI from .landsat import ( Landsat, Landsat1, @@ -54,23 +53,23 @@ Landsat9, ) from .levircd import LEVIRCDPlus -from .loveda import LoveDA, LoveDADataModule -from .naip import NAIP, NAIPChesapeakeDataModule -from .nasa_marine_debris import NASAMarineDebris, NASAMarineDebrisDataModule +from .loveda import LoveDA +from .naip import NAIP +from .nasa_marine_debris import NASAMarineDebris from .nwpu import VHR10 -from .oscd import OSCD, OSCDDataModule +from .oscd import OSCD from .patternnet import PatternNet -from .potsdam import Potsdam2D, Potsdam2DDataModule -from .resisc45 import RESISC45, RESISC45DataModule +from .potsdam import Potsdam2D +from .resisc45 import RESISC45 from .seco import SeasonalContrastS2 -from .sen12ms import SEN12MS, SEN12MSDataModule +from .sen12ms import SEN12MS from .sentinel import Sentinel, Sentinel2 -from .so2sat import So2Sat, So2SatDataModule +from .so2sat import So2Sat from .spacenet import SpaceNet, SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet5, SpaceNet7 -from .ucmerced import UCMerced, UCMercedDataModule +from .ucmerced import UCMerced from .utils import BoundingBox, concat_samples, merge_samples, stack_samples -from .vaihingen import Vaihingen2D, Vaihingen2DDataModule -from .xview import XView2, XView2DataModule +from .vaihingen import Vaihingen2D +from .xview import XView2 from .zuericrop import ZueriCrop __all__ = ( @@ -88,7 +87,6 @@ "ChesapeakeVA", "ChesapeakeWV", "ChesapeakeCVPR", - "ChesapeakeCVPRDataModule", "Landsat", "Landsat1", "Landsat2", @@ -101,46 +99,32 @@ "Landsat8", "Landsat9", "NAIP", - "NAIPChesapeakeDataModule", "Sentinel", "Sentinel2", # VisionDataset "ADVANCE", "BeninSmallHolderCashews", "BigEarthNet", - "BigEarthNetDataModule", "COWC", "COWCCounting", "COWCDetection", - "COWCCountingDataModule", "CV4AKenyaCropType", "ETCI2021", - "ETCI2021DataModule", "EuroSAT", - "EuroSATDataModule", "FAIR1M", - "FAIR1MDataModule", "GID15", "IDTReeS", "LandCoverAI", - "LandCoverAIDataModule", "LEVIRCDPlus", "LoveDA", - "LoveDADataModule", "NASAMarineDebris", - "NASAMarineDebrisDataModule", "OSCD", - "OSCDDataModule", "PatternNet", "Potsdam2D", - "Potsdam2DDataModule", "RESISC45", - "RESISC45DataModule", "SeasonalContrastS2", "SEN12MS", - "SEN12MSDataModule", "So2Sat", - "So2SatDataModule", "SpaceNet", "SpaceNet1", "SpaceNet2", @@ -148,14 +132,10 @@ "SpaceNet5", "SpaceNet7", "TropicalCycloneWindEstimation", - "CycloneDataModule", "UCMerced", - "UCMercedDataModule", "Vaihingen2D", - "Vaihingen2DDataModule", "VHR10", "XView2", - "XView2DataModule", "ZueriCrop", # Base classes "GeoDataset", diff --git a/torchgeo/datasets/advance.py b/torchgeo/datasets/advance.py index e5eee11b97b..c22819b0b7d 100644 --- a/torchgeo/datasets/advance.py +++ b/torchgeo/datasets/advance.py @@ -5,8 +5,9 @@ import glob import os -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, cast +import matplotlib.pyplot as plt import numpy as np import torch from PIL import Image @@ -229,3 +230,43 @@ def _download(self) -> None: download_and_extract_archive( url, self.root, filename=filename, md5=md5 if self.checksum else None ) + + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + + .. versionadded:: 0.2 + """ + image = np.rollaxis(sample["image"].numpy(), 0, 3) + label = cast(int, sample["label"].item()) + label_class = self.classes[label] + + showing_predictions = "prediction" in sample + if showing_predictions: + prediction = cast(int, sample["prediction"].item()) + prediction_class = self.classes[prediction] + + fig, ax = plt.subplots(figsize=(4, 4)) + ax.imshow(image) + ax.axis("off") + if show_titles: + title = f"Label: {label_class}" + if showing_predictions: + title += f"\nPrediction: {prediction_class}" + ax.set_title(title) + + if suptitle is not None: + plt.suptitle(suptitle) + return fig diff --git a/torchgeo/datasets/benin_cashews.py b/torchgeo/datasets/benin_cashews.py index fffcc7f67c2..3d968de12db 100644 --- a/torchgeo/datasets/benin_cashews.py +++ b/torchgeo/datasets/benin_cashews.py @@ -8,6 +8,7 @@ from functools import lru_cache from typing import Callable, Dict, Optional, Tuple +import matplotlib.pyplot as plt import numpy as np import rasterio import rasterio.features @@ -135,7 +136,7 @@ class BeninSmallHolderCashews(VisionDataset): "2020_10_30", ) - band_names = ( + ALL_BANDS = ( "B01", "B02", "B03", @@ -150,16 +151,17 @@ class BeninSmallHolderCashews(VisionDataset): "B12", "CLD", ) - - class_names = { - 0: "No data", - 1: "Well-managed planatation", - 2: "Poorly-managed planatation", - 3: "Non-planatation", - 4: "Residential", - 5: "Background", - 6: "Uncertain", - } + RGB_BANDS = ("B04", "B03", "B02") + + classes = [ + "No data", + "Well-managed planatation", + "Poorly-managed planatation", + "Non-planatation", + "Residential", + "Background", + "Uncertain", + ] # Same for all tiles tile_height = 1186 @@ -170,7 +172,7 @@ def __init__( root: str = "data", chip_size: int = 256, stride: int = 128, - bands: Tuple[str, ...] = band_names, + bands: Tuple[str, ...] = ALL_BANDS, transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, download: bool = False, api_key: Optional[str] = None, @@ -273,11 +275,11 @@ def _validate_bands(self, bands: Tuple[str, ...]) -> None: """ assert isinstance(bands, tuple), "The list of bands must be a tuple" for band in bands: - if band not in self.band_names: + if band not in self.ALL_BANDS: raise ValueError(f"'{band}' is an invalid band name.") @lru_cache(maxsize=128) - def _load_all_imagery(self, bands: Tuple[str, ...] = band_names) -> Tensor: + def _load_all_imagery(self, bands: Tuple[str, ...] = ALL_BANDS) -> Tensor: """Load all the imagery (across time) for the dataset. Optionally allows for subsetting of the bands that are loaded. @@ -410,3 +412,68 @@ def _download(self, api_key: Optional[str] = None) -> None: target_archive_path = os.path.join(self.root, self.target_meta["filename"]) for fn in [image_archive_path, target_archive_path]: extract_archive(fn, self.root) + + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + time_step: int = 0, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + time_step: time step at which to access image, beginning with 0 + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + + Raises: + ValueError: if the RGB bands are not included in ``self.bands`` + + .. versionadded:: 0.2 + """ + rgb_indices = [] + for band in self.RGB_BANDS: + if band in self.bands: + rgb_indices.append(self.bands.index(band)) + else: + raise ValueError("Dataset doesn't contain some of the RGB bands") + + num_time_points = sample["image"].shape[0] + assert time_step < num_time_points + + image = np.rollaxis(sample["image"][time_step, rgb_indices].numpy(), 0, 3) + image = np.clip(image / 3000, 0, 1) + mask = sample["mask"].numpy() + + num_panels = 2 + showing_predictions = "prediction" in sample + if showing_predictions: + predictions = sample["prediction"].numpy() + num_panels += 1 + + fig, axs = plt.subplots(ncols=num_panels, figsize=(4 * num_panels, 4)) + + axs[0].imshow(image) + axs[0].axis("off") + if show_titles: + axs[0].set_title(f"t={time_step}") + + axs[1].imshow(mask, vmin=0, vmax=6, interpolation="none") + axs[1].axis("off") + if show_titles: + axs[1].set_title("Mask") + + if showing_predictions: + axs[2].imshow(predictions, vmin=0, vmax=6, interpolation="none") + axs[2].axis("off") + if show_titles: + axs[2].set_title("Predictions") + + if suptitle is not None: + plt.suptitle(suptitle) + return fig diff --git a/torchgeo/datasets/bigearthnet.py b/torchgeo/datasets/bigearthnet.py index 409e0b82230..2ef2f8285ad 100644 --- a/torchgeo/datasets/bigearthnet.py +++ b/torchgeo/datasets/bigearthnet.py @@ -6,24 +6,18 @@ import glob import json import os -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional +import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl import rasterio import torch from rasterio.enums import Resampling from torch import Tensor -from torch.utils.data import DataLoader -from torchvision.transforms import Compose from .geo import VisionDataset from .utils import download_url, extract_archive, sort_sentinel2_bands -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" - class BigEarthNet(VisionDataset): """BigEarthNet dataset. @@ -125,74 +119,77 @@ class BigEarthNet(VisionDataset): """ - classes_43 = [ - "Agro-forestry areas", - "Airports", - "Annual crops associated with permanent crops", - "Bare rock", - "Beaches, dunes, sands", - "Broad-leaved forest", - "Burnt areas", - "Coastal lagoons", - "Complex cultivation patterns", - "Coniferous forest", - "Construction sites", - "Continuous urban fabric", - "Discontinuous urban fabric", - "Dump sites", - "Estuaries", - "Fruit trees and berry plantations", - "Green urban areas", - "Industrial or commercial units", - "Inland marshes", - "Intertidal flats", - "Land principally occupied by agriculture, with significant areas of " - "natural vegetation", - "Mineral extraction sites", - "Mixed forest", - "Moors and heathland", - "Natural grassland", - "Non-irrigated arable land", - "Olive groves", - "Pastures", - "Peatbogs", - "Permanently irrigated land", - "Port areas", - "Rice fields", - "Road and rail networks and associated land", - "Salines", - "Salt marshes", - "Sclerophyllous vegetation", - "Sea and ocean", - "Sparsely vegetated areas", - "Sport and leisure facilities", - "Transitional woodland/shrub", - "Vineyards", - "Water bodies", - "Water courses", - ] - classes_19 = [ - "Urban fabric", - "Industrial or commercial units", - "Arable land", - "Permanent crops", - "Pastures", - "Complex cultivation patterns", - "Land principally occupied by agriculture, with significant areas of natural " - "vegetation", - "Agro-forestry areas", - "Broad-leaved forest", - "Coniferous forest", - "Mixed forest", - "Natural grassland and sparsely vegetated areas", - "Moors, heathland and sclerophyllous vegetation", - "Transitional woodland, shrub", - "Beaches, dunes, sands", - "Inland wetlands", - "Coastal wetlands", - "Inland waters", - "Marine waters", - ] + class_sets = { + 19: [ + "Urban fabric", + "Industrial or commercial units", + "Arable land", + "Permanent crops", + "Pastures", + "Complex cultivation patterns", + "Land principally occupied by agriculture, with significant areas of" + " natural vegetation", + "Agro-forestry areas", + "Broad-leaved forest", + "Coniferous forest", + "Mixed forest", + "Natural grassland and sparsely vegetated areas", + "Moors, heathland and sclerophyllous vegetation", + "Transitional woodland, shrub", + "Beaches, dunes, sands", + "Inland wetlands", + "Coastal wetlands", + "Inland waters", + "Marine waters", + ], + 43: [ + "Agro-forestry areas", + "Airports", + "Annual crops associated with permanent crops", + "Bare rock", + "Beaches, dunes, sands", + "Broad-leaved forest", + "Burnt areas", + "Coastal lagoons", + "Complex cultivation patterns", + "Coniferous forest", + "Construction sites", + "Continuous urban fabric", + "Discontinuous urban fabric", + "Dump sites", + "Estuaries", + "Fruit trees and berry plantations", + "Green urban areas", + "Industrial or commercial units", + "Inland marshes", + "Intertidal flats", + "Land principally occupied by agriculture, with significant areas of" + " natural vegetation", + "Mineral extraction sites", + "Mixed forest", + "Moors and heathland", + "Natural grassland", + "Non-irrigated arable land", + "Olive groves", + "Pastures", + "Peatbogs", + "Permanently irrigated land", + "Port areas", + "Rice fields", + "Road and rail networks and associated land", + "Salines", + "Salt marshes", + "Sclerophyllous vegetation", + "Sea and ocean", + "Sparsely vegetated areas", + "Sport and leisure facilities", + "Transitional woodland/shrub", + "Vineyards", + "Water bodies", + "Water courses", + ], + } + label_converter = { 0: 0, 1: 0, @@ -227,6 +224,7 @@ class BigEarthNet(VisionDataset): 41: 18, 42: 18, } + splits_metadata = { "train": { "url": "https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/master/splits/train.csv?inline=false", # noqa: E501 @@ -292,7 +290,7 @@ def __init__( self.transforms = transforms self.download = download self.checksum = checksum - self.class2idx = {c: i for i, c in enumerate(self.classes_43)} + self.class2idx = {c: i for i, c in enumerate(self.class_sets[43])} self._verify() self.folders = self._load_folders() @@ -512,163 +510,70 @@ def _extract(self, filepath: str) -> None: if not filepath.endswith(".csv"): extract_archive(filepath) + def _onehot_labels_to_names( + self, label_mask: "np.typing.NDArray[np.bool_]" + ) -> List[str]: + """Gets a list of class names given a label mask. -class BigEarthNetDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the BigEarthNet dataset. - - Uses the train/val/test splits from the dataset. - """ + Args: + label_mask: a boolean mask corresponding to a set of labels or predictions - # (VV, VH, B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12) - # min/max band statistics computed on 100k random samples - band_mins_raw = torch.tensor( # type: ignore[attr-defined] - [-70.0, -72.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0] - ) - band_maxs_raw = torch.tensor( # type: ignore[attr-defined] - [ - 31.0, - 35.0, - 18556.0, - 20528.0, - 18976.0, - 17874.0, - 16611.0, - 16512.0, - 16394.0, - 16672.0, - 16141.0, - 16097.0, - 15336.0, - 15203.0, - ] - ) - - # min/max band statistics computed by percentile clipping the - # above to samples to [2, 98] - band_mins = torch.tensor( # type: ignore[attr-defined] - [-48.0, -42.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] - ) - band_maxs = torch.tensor( # type: ignore[attr-defined] - [ - 6.0, - 16.0, - 9859.0, - 12872.0, - 13163.0, - 14445.0, - 12477.0, - 12563.0, - 12289.0, - 15596.0, - 12183.0, - 9458.0, - 5897.0, - 5544.0, - ] - ) + Returns + a list of class names corresponding to the input mask + """ + labels = [] + for i, mask in enumerate(label_mask): + if mask: + labels.append(self.class_sets[self.num_classes][i]) + return labels - def __init__( + def plot( self, - root_dir: str, - bands: str = "all", - num_classes: int = 19, - batch_size: int = 64, - num_workers: int = 0, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for BigEarthNet based DataLoaders. + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. Args: - root_dir: The ``root`` arugment to pass to the BigEarthNet Dataset classes - bands: load Sentinel-1 bands, Sentinel-2, or both. one of {s1, s2, all} - num_classes: number of classes to load in target. one of {19, 43} - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.bands = bands - self.num_classes = num_classes - self.batch_size = batch_size - self.num_workers = num_workers - - if bands == "all": - self.mins = self.band_mins[:, None, None] - self.maxs = self.band_maxs[:, None, None] - elif bands == "s1": - self.mins = self.band_mins[:2, None, None] - self.maxs = self.band_maxs[:2, None, None] - else: - self.mins = self.band_mins[2:, None, None] - self.maxs = self.band_maxs[2:, None, None] - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset.""" - sample["image"] = sample["image"].float() - sample["image"] = (sample["image"] - self.mins) / (self.maxs - self.mins) - sample["image"] = torch.clip( # type: ignore[attr-defined] - sample["image"], min=0.0, max=1.0 - ) - return sample + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - BigEarthNet(self.root_dir, split="train", bands=self.bands, checksum=False) + Returns: + a matplotlib Figure with the rendered sample - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. + Raises: + ValueError: if ``self.bands`` is "s1" - This method is called once per GPU per run. + .. versionadded:: 0.2 """ - transforms = Compose([self.preprocess]) - self.train_dataset = BigEarthNet( - self.root_dir, - split="train", - bands=self.bands, - num_classes=self.num_classes, - transforms=transforms, - ) - self.val_dataset = BigEarthNet( - self.root_dir, - split="val", - bands=self.bands, - num_classes=self.num_classes, - transforms=transforms, - ) - self.test_dataset = BigEarthNet( - self.root_dir, - split="test", - bands=self.bands, - num_classes=self.num_classes, - transforms=transforms, - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training.""" - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation.""" - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing.""" - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) + if self.bands == "s2": + image = np.rollaxis(sample["image"][[3, 2, 1]].numpy(), 0, 3) + image = np.clip(image / 2000, 0, 1) + elif self.bands == "all": + image = np.rollaxis(sample["image"][[5, 4, 3]].numpy(), 0, 3) + image = np.clip(image / 2000, 0, 1) + elif self.bands == "s1": + image = sample["image"][0].numpy() + + label_mask = sample["label"].numpy().astype(np.bool_) + labels = self._onehot_labels_to_names(label_mask) + + showing_predictions = "prediction" in sample + if showing_predictions: + prediction_mask = sample["prediction"].numpy().astype(np.bool_) + predictions = self._onehot_labels_to_names(prediction_mask) + + fig, ax = plt.subplots(figsize=(4, 4)) + ax.imshow(image) + ax.axis("off") + if show_titles: + title = f"Labels: {', '.join(labels)}" + if showing_predictions: + title += f"\nPredictions: {', '.join(predictions)}" + ax.set_title(title) + + if suptitle is not None: + plt.suptitle(suptitle) + return fig diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py index 1709b1a9e13..4c46b9fed91 100644 --- a/torchgeo/datasets/chesapeake.py +++ b/torchgeo/datasets/chesapeake.py @@ -16,21 +16,10 @@ import shapely.geometry import shapely.ops import torch -import torch.nn.functional as F -from pytorch_lightning.core.datamodule import LightningDataModule from rasterio.crs import CRS -from torch import Tensor -from torch.utils.data import DataLoader -from torchvision.transforms import Compose -from ..samplers.batch import RandomBatchGeoSampler -from ..samplers.single import GridGeoSampler from .geo import GeoDataset, RasterDataset -from .utils import BoundingBox, download_url, extract_archive, stack_samples - -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" +from .utils import BoundingBox, download_url, extract_archive class Chesapeake(RasterDataset, abc.ABC): @@ -299,14 +288,30 @@ class ChesapeakeCVPR(GeoDataset): This dataset was organized to accompany the 2019 CVPR paper, "Large Scale High-Resolution Land Cover Mapping with Multi-Resolution Data". + The paper "Resolving label uncertainty with implicit generative models" added an + additional layer of data to this dataset containing a prior over the Chesapeake Bay + land cover classes generated from the NLCD land cover labels. For more information + about this layer see `the dataset documentation + `_. + If you use this dataset in your research, please cite the following paper: * https://doi.org/10.1109/cvpr.2019.01301 """ - url = "https://lilablobssc.blob.core.windows.net/lcmcvpr2019/cvpr_chesapeake_landcover.zip" # noqa: E501 - filename = "cvpr_chesapeake_landcover.zip" - md5 = "1225ccbb9590e9396875f221e5031514" + subdatasets = ["base", "prior_extension"] + urls = { + "base": "https://lilablobssc.blob.core.windows.net/lcmcvpr2019/cvpr_chesapeake_landcover.zip", # noqa: E501 + "prior_extension": "https://zenodo.org/record/5652512/files/cvpr_chesapeake_landcover_prior_extension.zip?download=1", # noqa: E501 + } + filenames = { + "base": "cvpr_chesapeake_landcover.zip", + "prior_extension": "cvpr_chesapeake_landcover_prior_extension.zip", + } + md5s = { + "base": "1225ccbb9590e9396875f221e5031514", + "prior_extension": "8f43ec30e155274dd652e157c48d2598", + } crs = CRS.from_epsg(3857) res = 1 @@ -319,6 +324,7 @@ class ChesapeakeCVPR(GeoDataset): "nlcd", "lc", "buildings", + "prior_from_cooccurrences_101_31_no_osm_no_buildings", ] states = ["de", "md", "va", "wv", "pa", "ny"] splits = ( @@ -327,6 +333,7 @@ class ChesapeakeCVPR(GeoDataset): + [f"{state}-test" for state in states] ) + # these are used to check the integrity of the dataset files = [ "de_1m_2013_extended-debuffered-test_tiles", "de_1m_2013_extended-debuffered-train_tiles", @@ -346,6 +353,14 @@ class ChesapeakeCVPR(GeoDataset): "wv_1m_2014_extended-debuffered-test_tiles", "wv_1m_2014_extended-debuffered-train_tiles", "wv_1m_2014_extended-debuffered-val_tiles", + "wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_buildings.tif", + "wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-off.tif", # noqa: E501 + "wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-on.tif", # noqa: E501 + "wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_lc.tif", + "wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_naip-new.tif", + "wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_naip-old.tif", + "wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_nlcd.tif", + "wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif", # noqa: E501 "spatial_index.geojson", ] @@ -376,7 +391,8 @@ def __init__( splits: a list of strings in the format "{state}-{train,val,test}" indicating the subset of data to use, for example "ny-train" layers: a list containing a subset of "naip-new", "naip-old", "lc", "nlcd", - "landsat-leaf-on", "landsat-leaf-off", "buildings" indicating which + "landsat-leaf-on", "landsat-leaf-off", "buildings", or + "prior_from_cooccurrences_101_31_no_osm_no_buildings" indicating which layers to load transforms: a function/transform that takes an input sample and returns a transformed version @@ -410,6 +426,12 @@ def __init__( box = shapely.geometry.shape(row["geometry"]) minx, miny, maxx, maxy = box.bounds coords = (minx, maxx, miny, maxy, mint, maxt) + + prior_fn = row["properties"]["lc"].replace( + "lc.tif", + "prior_from_cooccurrences_101_31_no_osm_no_buildings.tif", + ) + self.index.insert( i, coords, @@ -421,6 +443,7 @@ def __init__( "lc": row["properties"]["lc"], "nlcd": row["properties"]["nlcd"], "buildings": row["properties"]["buildings"], + "prior_from_cooccurrences_101_31_no_osm_no_buildings": prior_fn, # noqa: E501 }, ) @@ -478,7 +501,12 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: "landsat-leaf-off", ]: sample["image"].append(data) - elif layer in ["lc", "nlcd", "buildings"]: + elif layer in [ + "lc", + "nlcd", + "buildings", + "prior_from_cooccurrences_101_31_no_osm_no_buildings", + ]: sample["mask"].append(data) else: raise IndexError(f"query: {query} spans multiple tiles which is not valid") @@ -514,7 +542,12 @@ def exists(filename: str) -> bool: return # Check if the zip files have already been downloaded - if os.path.exists(os.path.join(self.root, self.filename)): + if all( + [ + os.path.exists(os.path.join(self.root, self.filenames[subdataset])) + for subdataset in self.subdatasets + ] + ): self._extract() return @@ -532,299 +565,15 @@ def exists(filename: str) -> bool: def _download(self) -> None: """Download the dataset.""" - download_url(self.url, self.root, filename=self.filename, md5=self.md5) + for subdataset in self.subdatasets: + download_url( + self.urls[subdataset], + self.root, + filename=self.filenames[subdataset], + md5=self.md5s[subdataset], + ) def _extract(self) -> None: """Extract the dataset.""" - extract_archive(os.path.join(self.root, self.filename)) - - -class ChesapeakeCVPRDataModule(LightningDataModule): - """LightningDataModule implementation for the Chesapeake CVPR Land Cover dataset. - - Uses the random splits defined per state to partition tiles into train, val, - and test sets. - """ - - def __init__( - self, - root_dir: str, - train_splits: List[str], - val_splits: List[str], - test_splits: List[str], - patches_per_tile: int = 200, - patch_size: int = 256, - batch_size: int = 64, - num_workers: int = 0, - class_set: int = 7, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for Chesapeake CVPR based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the ChesapeakeCVPR Dataset - classes - train_splits: The splits used to train the model, e.g. ["ny-train"] - val_splits: The splits used to validate the model, e.g. ["ny-val"] - test_splits: The splits used to test the model, e.g. ["ny-test"] - patches_per_tile: The number of patches per tile to sample - patch_size: The size of each patch in pixels (test patches will be 1.5 times - this size) - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - class_set: The high-resolution land cover class set to use - 5 or 7 - """ - super().__init__() # type: ignore[no-untyped-call] - for state in train_splits + val_splits + test_splits: - assert state in ChesapeakeCVPR.splits - assert class_set in [5, 7] - - self.root_dir = root_dir - self.train_splits = train_splits - self.val_splits = val_splits - self.test_splits = test_splits - self.layers = ["naip-new", "lc"] - self.patches_per_tile = patches_per_tile - self.patch_size = patch_size - # This is a rough estimate of how large of a patch we will need to sample in - # EPSG:3857 in order to guarantee a large enough patch in the local CRS. - self.original_patch_size = int(patch_size * 2.0) - self.batch_size = batch_size - self.num_workers = num_workers - self.class_set = class_set - - def pad_to( - self, size: int = 512, image_value: int = 0, mask_value: int = 0 - ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]: - """Returns a function to perform a padding transform on a single sample. - - Args: - size: output image size - image_value: value to pad image with - mask_value: value to pad mask with - - Returns: - function to perform padding - """ - - def pad_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]: - _, height, width = sample["image"].shape - assert height <= size and width <= size - - height_pad = size - height - width_pad = size - width - - # See https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html - # for a description of the format of the padding tuple - sample["image"] = F.pad( - sample["image"], - (0, width_pad, 0, height_pad), - mode="constant", - value=image_value, - ) - sample["mask"] = F.pad( - sample["mask"], - (0, width_pad, 0, height_pad), - mode="constant", - value=mask_value, - ) - return sample - - return pad_inner - - def center_crop( - self, size: int = 512 - ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]: - """Returns a function to perform a center crop transform on a single sample. - - Args: - size: output image size - - Returns: - function to perform center crop - """ - - def center_crop_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]: - _, height, width = sample["image"].shape - - y1 = (height - size) // 2 - x1 = (width - size) // 2 - sample["image"] = sample["image"][:, y1 : y1 + size, x1 : x1 + size] - sample["mask"] = sample["mask"][:, y1 : y1 + size, x1 : x1 + size] - - return sample - - return center_crop_inner - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Preprocesses a single sample. - - Args: - sample: sample dictionary containing image and mask - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"] / 255.0 - sample["mask"] = sample["mask"] - sample["mask"] = sample["mask"].squeeze() - - if self.class_set == 5: - sample["mask"][sample["mask"] == 5] = 4 - sample["mask"][sample["mask"] == 6] = 4 - - sample["image"] = sample["image"].float() - sample["mask"] = sample["mask"].long() - - return sample - - def nodata_check( - self, size: int = 512 - ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]: - """Returns a function to check for nodata or mis-sized input. - - Args: - size: output image size - - Returns: - function to check for nodata values - """ - - def nodata_check_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]: - num_channels, height, width = sample["image"].shape - - if height < size or width < size: - sample["image"] = torch.zeros( # type: ignore[attr-defined] - (num_channels, size, size) - ) - sample["mask"] = torch.zeros((size, size)) # type: ignore[attr-defined] - - return sample - - return nodata_check_inner - - def prepare_data(self) -> None: - """Confirms that the dataset is downloaded on the local node. - - This method is called once per node, while :func:`setup` is called once per GPU. - """ - ChesapeakeCVPR( - self.root_dir, - splits=self.train_splits, - layers=self.layers, - transforms=None, - download=False, - checksum=False, - ) - - def setup(self, stage: Optional[str] = None) -> None: - """Create the train/val/test splits based on the original Dataset objects. - - The splits should be done here vs. in :func:`__init__` per the docs: - https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup. - - Args: - stage: stage to set up - """ - train_transforms = Compose( - [ - self.center_crop(self.patch_size), - self.nodata_check(self.patch_size), - self.preprocess, - ] - ) - val_transforms = Compose( - [ - self.center_crop(self.patch_size), - self.nodata_check(self.patch_size), - self.preprocess, - ] - ) - test_transforms = Compose( - [ - self.pad_to(self.original_patch_size, image_value=0, mask_value=0), - self.preprocess, - ] - ) - - self.train_dataset = ChesapeakeCVPR( - self.root_dir, - splits=self.train_splits, - layers=self.layers, - transforms=train_transforms, - download=False, - checksum=False, - ) - self.val_dataset = ChesapeakeCVPR( - self.root_dir, - splits=self.val_splits, - layers=self.layers, - transforms=val_transforms, - download=False, - checksum=False, - ) - self.test_dataset = ChesapeakeCVPR( - self.root_dir, - splits=self.test_splits, - layers=self.layers, - transforms=test_transforms, - download=False, - checksum=False, - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - sampler = RandomBatchGeoSampler( - self.train_dataset, - size=self.original_patch_size, - batch_size=self.batch_size, - length=self.patches_per_tile * len(self.train_dataset), - ) - return DataLoader( - self.train_dataset, - batch_sampler=sampler, - num_workers=self.num_workers, - collate_fn=stack_samples, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - sampler = GridGeoSampler( - self.val_dataset, - size=self.original_patch_size, - stride=self.original_patch_size, - ) - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - sampler=sampler, - num_workers=self.num_workers, - collate_fn=stack_samples, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - sampler = GridGeoSampler( - self.test_dataset, - size=self.original_patch_size, - stride=self.original_patch_size, - ) - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - sampler=sampler, - num_workers=self.num_workers, - collate_fn=stack_samples, - ) + for subdataset in self.subdatasets: + extract_archive(os.path.join(self.root, self.filenames[subdataset])) diff --git a/torchgeo/datasets/cowc.py b/torchgeo/datasets/cowc.py index 35bbdc54be6..4efec967329 100644 --- a/torchgeo/datasets/cowc.py +++ b/torchgeo/datasets/cowc.py @@ -6,22 +6,17 @@ import abc import csv import os -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, cast +import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl import torch from PIL import Image -from torch import Generator, Tensor # type: ignore[attr-defined] -from torch.utils.data import DataLoader, random_split +from torch import Tensor from .geo import VisionDataset from .utils import check_integrity, download_and_extract_archive -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" - class COWC(VisionDataset, abc.ABC): """Abstract base class for the COWC dataset. @@ -196,6 +191,48 @@ def _download(self) -> None: md5=md5 if self.checksum else None, ) + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`VisionClassificationDataset.__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + + .. versionadded:: 0.2 + """ + image = sample["image"] + label = cast(str, sample["label"].item()) + + showing_predictions = "prediction" in sample + if showing_predictions: + prediction = cast(str, sample["prediction"].item()) + else: + prediction = None + + fig, ax = plt.subplots(1, 1, figsize=(10, 10)) + ax.imshow(image.permute(1, 2, 0)) + ax.axis("off") + + if show_titles: + title = f"Label: {label}" + if prediction is not None: + title += f"\nPrediction: {prediction}" + ax.set_title(title) + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig + class COWCCounting(COWC): """COWC Dataset for car counting.""" @@ -268,110 +305,3 @@ class COWCDetection(COWC): # 4. Unknown # # May need new abstract base class. Will need subclasses for different patch sizes. - - -class COWCCountingDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the COWC Counting dataset.""" - - def __init__( - self, - root_dir: str, - seed: int, - batch_size: int = 64, - num_workers: int = 0, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for COWC Counting based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the COWCCounting Dataset class - seed: The seed value to use when doing the dataset random_split - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.seed = seed - self.batch_size = batch_size - self.num_workers = num_workers - - def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image and target - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"] / 255.0 # scale to [0, 1] - sample["label"] = sample["label"].float() - return sample - - def prepare_data(self) -> None: - """Initialize the main ``Dataset`` objects for use in :func:`setup`. - - This includes optionally downloading the dataset. This is done once per node, - while :func:`setup` is done once per GPU. - """ - COWCCounting(self.root_dir, download=False) - - def setup(self, stage: Optional[str] = None) -> None: - """Create the train/val/test splits based on the original Dataset objects. - - The splits should be done here vs. in :func:`__init__` per the docs: - https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup. - - Args: - stage: stage to set up - """ - train_val_dataset = COWCCounting( - self.root_dir, split="train", transforms=self.custom_transform - ) - self.test_dataset = COWCCounting( - self.root_dir, split="test", transforms=self.custom_transform - ) - self.train_dataset, self.val_dataset = random_split( - train_val_dataset, - [len(train_val_dataset) - len(self.test_dataset), len(self.test_dataset)], - generator=Generator().manual_seed(self.seed), - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) diff --git a/torchgeo/datasets/cv4a_kenya_crop_type.py b/torchgeo/datasets/cv4a_kenya_crop_type.py index ac3ba679098..48e7e176553 100644 --- a/torchgeo/datasets/cv4a_kenya_crop_type.py +++ b/torchgeo/datasets/cv4a_kenya_crop_type.py @@ -8,6 +8,7 @@ from functools import lru_cache from typing import Callable, Dict, List, Optional, Tuple +import matplotlib.pyplot as plt import numpy as np import torch from PIL import Image @@ -102,6 +103,8 @@ class CV4AKenyaCropType(VisionDataset): "CLD", ) + RGB_BANDS = ["B04", "B03", "B02"] + # Same for all tiles tile_height = 3035 tile_width = 2016 @@ -400,3 +403,68 @@ def _download(self, api_key: Optional[str] = None) -> None: target_archive_path = os.path.join(self.root, self.target_meta["filename"]) for fn in [image_archive_path, target_archive_path]: extract_archive(fn, self.root) + + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + time_step: int = 0, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + time_step: time step at which to access image, beginning with 0 + suptitle: optional suptitle to use for figure + + Returns: + a matplotlib Figure with the rendered sample + + .. versionadded:: 0.2 + """ + rgb_indices = [] + for band in self.RGB_BANDS: + if band in self.bands: + rgb_indices.append(self.bands.index(band)) + else: + raise ValueError("Dataset doesn't contain some of the RGB bands") + + if "prediction" in sample: + prediction = sample["prediction"] + n_cols = 3 + else: + n_cols = 2 + + image, mask = sample["image"], sample["mask"] + + assert time_step <= image.shape[0] - 1, ( + "The specified time step" + " does not exist, image only contains {} time" + " instances." + ).format(image.shape[0]) + + image = image[time_step, rgb_indices, :, :] + + fig, axs = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, n_cols * 5)) + + axs[0].imshow(image.permute(1, 2, 0)) + axs[0].axis("off") + axs[1].imshow(mask) + axs[1].axis("off") + + if "prediction" in sample: + axs[2].imshow(prediction) + axs[2].axis("off") + if show_titles: + axs[2].set_title("Prediction") + + if show_titles: + axs[0].set_title("Image") + axs[1].set_title("Mask") + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index 37c20ca42d6..0229f1f85dc 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -10,20 +10,13 @@ import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl import torch from PIL import Image -from sklearn.model_selection import GroupShuffleSplit from torch import Tensor -from torch.utils.data import DataLoader, Subset from .geo import VisionDataset from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" - class TropicalCycloneWindEstimation(VisionDataset): """Tropical Cyclone Wind Estimation Competition dataset. @@ -254,157 +247,3 @@ def plot( plt.suptitle(suptitle) return fig - - -class CycloneDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the NASA Cyclone dataset. - - Implements 80/20 train/val splits based on hurricane storm ids. - See :func:`setup` for more details. - """ - - def __init__( - self, - root_dir: str, - seed: int, - batch_size: int = 64, - num_workers: int = 0, - api_key: Optional[str] = None, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for NASA Cyclone based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the - TropicalCycloneWindEstimation Datasets classes - seed: The seed value to use when doing the sklearn based GroupShuffleSplit - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - api_key: The RadiantEarth MLHub API key to use if the dataset needs to be - downloaded - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.seed = seed - self.batch_size = batch_size - self.num_workers = num_workers - self.api_key = api_key - - def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image and target - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"] / 255.0 # scale to [0,1] - sample["image"] = ( - sample["image"].unsqueeze(0).repeat(3, 1, 1) - ) # convert to 3 channel - sample["label"] = torch.as_tensor( # type: ignore[attr-defined] - sample["label"] - ).float() - - return sample - - def prepare_data(self) -> None: - """Initialize the main ``Dataset`` objects for use in :func:`setup`. - - This includes optionally downloading the dataset. This is done once per node, - while :func:`setup` is done once per GPU. - """ - TropicalCycloneWindEstimation( - self.root_dir, - split="train", - transforms=self.custom_transform, - download=self.api_key is not None, - api_key=self.api_key, - ) - - def setup(self, stage: Optional[str] = None) -> None: - """Create the train/val/test splits based on the original Dataset objects. - - The splits should be done here vs. in :func:`__init__` per the docs: - https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup. - - We split samples between train/val by the ``storm_id`` property. I.e. all - samples with the same ``storm_id`` value will be either in the train or the val - split. This is important to test one type of generalizability -- given a new - storm, can we predict its windspeed. The test set, however, contains *some* - storms from the training set (specifically, the latter parts of the storms) as - well as some novel storms. - - Args: - stage: stage to set up - """ - self.all_train_dataset = TropicalCycloneWindEstimation( - self.root_dir, - split="train", - transforms=self.custom_transform, - download=False, - ) - - self.all_test_dataset = TropicalCycloneWindEstimation( - self.root_dir, - split="test", - transforms=self.custom_transform, - download=False, - ) - - storm_ids = [] - for item in self.all_train_dataset.collection: - storm_id = item["href"].split("/")[0].split("_")[-2] - storm_ids.append(storm_id) - - train_indices, val_indices = next( - GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=self.seed).split( - storm_ids, groups=storm_ids - ) - ) - - self.train_dataset = Subset(self.all_train_dataset, train_indices) - self.val_dataset = Subset(self.all_train_dataset, val_indices) - self.test_dataset = Subset( - self.all_test_dataset, range(len(self.all_test_dataset)) - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) diff --git a/torchgeo/datasets/etci2021.py b/torchgeo/datasets/etci2021.py index bb10da22bff..dbcf667ce95 100644 --- a/torchgeo/datasets/etci2021.py +++ b/torchgeo/datasets/etci2021.py @@ -5,16 +5,13 @@ import glob import os -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl import torch from PIL import Image -from torch import Generator, Tensor # type: ignore[attr-defined] -from torch.utils.data import DataLoader, random_split -from torchvision.transforms import Normalize +from torch import Tensor from .geo import VisionDataset from .utils import download_and_extract_archive @@ -320,140 +317,3 @@ def plot( if suptitle is not None: plt.suptitle(suptitle) return fig - - -class ETCI2021DataModule(pl.LightningDataModule): - """LightningDataModule implementation for the ETCI2021 dataset. - - Splits the existing train split from the dataset into train/val with 80/20 - proportions, then uses the existing val dataset as the test data. - - .. versionadded:: 0.2 - """ - - band_means = torch.tensor( # type: ignore[attr-defined] - [0.52253931, 0.52253931, 0.52253931, 0.61221701, 0.61221701, 0.61221701, 0] - ) - - band_stds = torch.tensor( # type: ignore[attr-defined] - [0.35221376, 0.35221376, 0.35221376, 0.37364622, 0.37364622, 0.37364622, 1] - ) - - def __init__( - self, - root_dir: str, - seed: int = 0, - batch_size: int = 64, - num_workers: int = 0, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for ETCI2021 based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the ETCI2021 Dataset classes - seed: The seed value to use when doing the dataset random_split - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.seed = seed - self.batch_size = batch_size - self.num_workers = num_workers - - self.norm = Normalize(self.band_means, self.band_stds) - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Notably, moves the given water mask to act as an input layer. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - image = sample["image"] - water_mask = sample["mask"][0].unsqueeze(0) - flood_mask = sample["mask"][1] - flood_mask = (flood_mask > 0).long() - - sample["image"] = torch.cat( # type: ignore[attr-defined] - [image, water_mask], dim=0 - ).float() - sample["image"] /= 255.0 - sample["image"] = self.norm(sample["image"]) - sample["mask"] = flood_mask - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - ETCI2021(self.root_dir, checksum=False) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - train_val_dataset = ETCI2021( - self.root_dir, split="train", transforms=self.preprocess - ) - self.test_dataset = ETCI2021( - self.root_dir, split="val", transforms=self.preprocess - ) - - size_train_val = len(train_val_dataset) - size_train = int(0.8 * size_train_val) - size_val = size_train_val - size_train - - self.train_dataset, self.val_dataset = random_split( - train_val_dataset, - [size_train, size_val], - generator=Generator().manual_seed(self.seed), - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) diff --git a/torchgeo/datasets/eurosat.py b/torchgeo/datasets/eurosat.py index 9ba06c1c683..e1e14072039 100644 --- a/torchgeo/datasets/eurosat.py +++ b/torchgeo/datasets/eurosat.py @@ -4,15 +4,11 @@ """EuroSAT dataset.""" import os -from typing import Any, Callable, Dict, Optional, cast +from typing import Callable, Dict, Optional, cast import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl -import torch from torch import Tensor -from torch.utils.data import DataLoader -from torchvision.transforms import Compose, Normalize from .geo import VisionClassificationDataset from .utils import check_integrity, download_url, extract_archive, rasterio_loader @@ -229,138 +225,3 @@ def plot( if suptitle is not None: plt.suptitle(suptitle) return fig - - -class EuroSATDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the EuroSAT dataset. - - Uses the train/val/test splits from the dataset. - - .. versionadded:: 0.2 - """ - - band_means = torch.tensor( # type: ignore[attr-defined] - [ - 1354.40546513, - 1118.24399958, - 1042.92983953, - 947.62620298, - 1199.47283961, - 1999.79090914, - 2369.22292565, - 2296.82608323, - 732.08340178, - 12.11327804, - 1819.01027855, - 1118.92391149, - 2594.14080798, - ] - ) - - band_stds = torch.tensor( # type: ignore[attr-defined] - [ - 245.71762908, - 333.00778264, - 395.09249139, - 593.75055589, - 566.4170017, - 861.18399006, - 1086.63139075, - 1117.98170791, - 404.91978886, - 4.77584468, - 1002.58768311, - 761.30323499, - 1231.58581042, - ] - ) - - def __init__( - self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any - ) -> None: - """Initialize a LightningDataModule for EuroSAT based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the EuroSAT Dataset classes - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.batch_size = batch_size - self.num_workers = num_workers - - self.norm = Normalize(self.band_means, self.band_stds) - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] = self.norm(sample["image"]) - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - EuroSAT(self.root_dir) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - transforms = Compose([self.preprocess]) - - self.train_dataset = EuroSAT(self.root_dir, "train", transforms=transforms) - self.val_dataset = EuroSAT(self.root_dir, "val", transforms=transforms) - self.test_dataset = EuroSAT(self.root_dir, "test", transforms=transforms) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) diff --git a/torchgeo/datasets/fair1m.py b/torchgeo/datasets/fair1m.py index e75f27be35f..c8e2184f23a 100644 --- a/torchgeo/datasets/fair1m.py +++ b/torchgeo/datasets/fair1m.py @@ -11,33 +11,12 @@ import matplotlib.patches as patches import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl import torch from PIL import Image from torch import Tensor -from torch.utils.data import DataLoader -from torchvision.transforms import Compose -from ..datasets.utils import check_integrity, dataset_split, extract_archive from .geo import VisionDataset - -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" - - -def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]: - """Custom object detection collate fn to handle variable number of boxes. - - Args: - batch: list of sample dicts return by dataset - Returns: - batch dict output - """ - output: Dict[str, Any] = {} - output["image"] = torch.stack([sample["image"] for sample in batch]) - output["boxes"] = [sample["boxes"] for sample in batch] - return output +from .utils import check_integrity, extract_archive def parse_pascal_voc(path: str) -> Dict[str, Any]: @@ -350,102 +329,3 @@ def plot( plt.suptitle(suptitle) return fig - - -class FAIR1MDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the FAIR1M dataset.""" - - def __init__( - self, - root_dir: str, - batch_size: int = 64, - num_workers: int = 0, - val_split_pct: float = 0.2, - test_split_pct: float = 0.2, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for FAIR1M based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the FAIR1M Dataset classes - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - val_split_pct: What percentage of the dataset to use as a validation set - test_split_pct: What percentage of the dataset to use as a test set - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.batch_size = batch_size - self.num_workers = num_workers - self.val_split_pct = val_split_pct - self.test_split_pct = test_split_pct - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - return sample - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - transforms = Compose([self.preprocess]) - - dataset = FAIR1M(self.root_dir, transforms=transforms) - self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( - dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - collate_fn=collate_fn, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - collate_fn=collate_fn, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - collate_fn=collate_fn, - ) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 5c09a2d8930..93df5d7f482 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -134,6 +134,8 @@ def __and__(self, other: "GeoDataset") -> "IntersectionDataset": Raises: ValueError: if other is not a :class:`GeoDataset` + + .. versionadded:: 0.2 """ return IntersectionDataset(self, other) @@ -148,6 +150,8 @@ def __or__(self, other: "GeoDataset") -> "UnionDataset": Raises: ValueError: if other is not a :class:`GeoDataset` + + .. versionadded:: 0.2 """ return UnionDataset(self, other) @@ -222,6 +226,8 @@ def crs(self) -> CRS: Returns: the :term:`coordinate reference system (CRS)` + + .. versionadded:: 0.2 """ return self._crs @@ -233,6 +239,8 @@ def crs(self, new_crs: CRS) -> None: Args: new_crs: new :term:`coordinate reference system (CRS)` + + .. versionadded:: 0.2 """ if new_crs == self._crs: return @@ -810,6 +818,8 @@ class IntersectionDataset(GeoDataset): .. code-block:: python dataset = landsat & cdl + + .. versionadded:: 0.2 """ def __init__( @@ -920,6 +930,8 @@ class UnionDataset(GeoDataset): .. code-block:: python dataset = landsat7 | landsat8 + + .. versionadded:: 0.2 """ def __init__( diff --git a/torchgeo/datasets/landcoverai.py b/torchgeo/datasets/landcoverai.py index e579d668d63..2fecb5d6e10 100644 --- a/torchgeo/datasets/landcoverai.py +++ b/torchgeo/datasets/landcoverai.py @@ -6,24 +6,18 @@ import hashlib import os from functools import lru_cache -from typing import Any, Callable, Dict, Optional +from typing import Callable, Dict, Optional import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl import torch from matplotlib.colors import ListedColormap from PIL import Image from torch import Tensor -from torch.utils.data import DataLoader from .geo import VisionDataset from .utils import check_integrity, download_and_extract_archive, working_dir -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" - class LandCoverAI(VisionDataset): r"""LandCover.ai dataset. @@ -266,110 +260,3 @@ def plot( if suptitle is not None: plt.suptitle(suptitle) return fig - - -class LandCoverAIDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the LandCover.ai dataset. - - Uses the train/val/test splits from the dataset. - """ - - def __init__( - self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any - ) -> None: - """Initialize a LightningDataModule for LandCover.ai based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the Landcover.AI Dataset classes - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.batch_size = batch_size - self.num_workers = num_workers - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image and mask - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"] / 255.0 - - sample["image"] = sample["image"].float() - sample["mask"] = sample["mask"].float().unsqueeze(0) + 1 - - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - _ = LandCoverAI(self.root_dir, download=False, checksum=False) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - train_transforms = self.preprocess - val_test_transforms = self.preprocess - - self.train_dataset = LandCoverAI( - self.root_dir, split="train", transforms=train_transforms - ) - - self.val_dataset = LandCoverAI( - self.root_dir, split="val", transforms=val_test_transforms - ) - - self.test_dataset = LandCoverAI( - self.root_dir, split="test", transforms=val_test_transforms - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index 24d76ca6594..9098a23b585 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -7,6 +7,7 @@ import os from typing import Callable, Dict, List, Optional +import matplotlib.pyplot as plt import numpy as np import torch from PIL import Image @@ -47,6 +48,7 @@ class LEVIRCDPlus(VisionDataset): url = "https://drive.google.com/file/d/1JamSsxiytXdzAIk6VDVWfc-OsX-81U81" md5 = "1adf156f628aa32fb2e8fe6cada16c04" filename = "LEVIR-CD+.zip" + directory = "LEVIR-CD+" splits = ["train", "test"] def __init__( @@ -88,7 +90,7 @@ def __init__( + "You can use download=True to download it" ) - self.files = self._load_files(self.root, self.split) + self.files = self._load_files(self.root, self.directory, self.split) def __getitem__(self, index: int) -> Dict[str, Tensor]: """Return an index within the dataset. @@ -120,23 +122,26 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self, root: str, split: str) -> List[Dict[str, str]]: + def _load_files( + self, root: str, directory: str, split: str + ) -> List[Dict[str, str]]: """Return the paths of the files in the dataset. Args: root: root dir of dataset + directory: sub directory LEVIR-CD+ split: subset of dataset, one of [train, test] Returns: list of dicts containing paths for each pair of image1, image2, mask """ files = [] - images = glob.glob(os.path.join(root, split, "A", "*.png")) + images = glob.glob(os.path.join(root, directory, split, "A", "*.png")) images = sorted([os.path.basename(image) for image in images]) for image in images: - image1 = os.path.join(root, split, "A", image) - image2 = os.path.join(root, split, "B", image) - mask = os.path.join(root, split, "label", image) + image1 = os.path.join(root, directory, split, "A", image) + image2 = os.path.join(root, directory, split, "B", image) + mask = os.path.join(root, directory, split, "label", image) files.append(dict(image1=image1, image2=image2, mask=mask)) return files @@ -181,7 +186,7 @@ def _check_integrity(self) -> bool: True if the dataset directories and split files are found, else False """ for filename in self.splits: - filepath = os.path.join(self.root, filename) + filepath = os.path.join(self.root, self.directory, filename) if not os.path.exists(filepath): return False return True @@ -202,3 +207,53 @@ def _download(self) -> None: filename=self.filename, md5=self.md5 if self.checksum else None, ) + + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional suptitle to use for figure + + Returns: + a matplotlib Figure with the rendered sample + + .. versionadded:: 0.2 + """ + image1, image2, mask = (sample["image"][0], sample["image"][1], sample["mask"]) + ncols = 3 + + if "prediction" in sample: + prediction = sample["prediction"] + ncols += 1 + + fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5)) + + axs[0].imshow(image1.permute(1, 2, 0)) + axs[0].axis("off") + axs[1].imshow(image2.permute(1, 2, 0)) + axs[1].axis("off") + axs[2].imshow(mask) + axs[2].axis("off") + + if "prediction" in sample: + axs[3].imshow(prediction) + axs[3].axis("off") + if show_titles: + axs[3].set_title("Prediction") + + if show_titles: + axs[0].set_title("Image 1") + axs[1].set_title("Image 2") + axs[2].set_title("Mask") + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig diff --git a/torchgeo/datasets/loveda.py b/torchgeo/datasets/loveda.py index b3a0a52e8ca..30fe98adfb4 100644 --- a/torchgeo/datasets/loveda.py +++ b/torchgeo/datasets/loveda.py @@ -5,23 +5,17 @@ import glob import os -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl import torch from PIL import Image from torch import Tensor -from torch.utils.data import DataLoader from .geo import VisionDataset from .utils import download_and_extract_archive -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" - class LoveDA(VisionDataset): """LoveDA dataset. @@ -305,117 +299,3 @@ def plot( plt.suptitle(suptitle) return fig - - -class LoveDADataModule(pl.LightningDataModule): - """LightningDataModule implementation for the LoveDA dataset. - - Uses the train/val/test splits from the dataset. - """ - - def __init__( - self, - root_dir: str, - scene: List[str], - batch_size: int = 32, - num_workers: int = 0, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for LoveDA based DataLoaders. - - Args: - root_dir: The ``root`` argument to pass to LoveDA Dataset classes - scene: specify whether to load only 'urban', only 'rural' or both - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.scene = scene - self.batch_size = batch_size - self.num_workers = num_workers - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image and mask - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"] / 255.0 - - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - _ = LoveDA(self.root_dir, scene=self.scene, download=False, checksum=False) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - train_transforms = self.preprocess - val_test_transforms = self.preprocess - - self.train_dataset = LoveDA( - self.root_dir, split="train", scene=self.scene, transforms=train_transforms - ) - - self.val_dataset = LoveDA( - self.root_dir, split="val", scene=self.scene, transforms=val_test_transforms - ) - - self.test_dataset = LoveDA( - self.root_dir, - split="test", - scene=self.scene, - transforms=val_test_transforms, - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) diff --git a/torchgeo/datasets/naip.py b/torchgeo/datasets/naip.py index 02cfe1e33f6..b6b4bceceb3 100644 --- a/torchgeo/datasets/naip.py +++ b/torchgeo/datasets/naip.py @@ -3,20 +3,7 @@ """National Agriculture Imagery Program (NAIP) dataset.""" -from typing import Any, Dict, Optional - -import pytorch_lightning as pl -from torch.utils.data import DataLoader - -from ..samplers.batch import RandomBatchGeoSampler -from ..samplers.single import GridGeoSampler -from .chesapeake import Chesapeake13 from .geo import RasterDataset -from .utils import BoundingBox, stack_samples - -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" class NAIP(RasterDataset): @@ -55,147 +42,3 @@ class NAIP(RasterDataset): # Plotting all_bands = ["R", "G", "B", "NIR"] rgb_bands = ["R", "G", "B"] - - -class NAIPChesapeakeDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the NAIP and Chesapeake datasets. - - Uses the train/val/test splits from the dataset. - """ - - # TODO: tune these hyperparams - length = 1000 - stride = 128 - - def __init__( - self, - naip_root_dir: str, - chesapeake_root_dir: str, - batch_size: int = 64, - num_workers: int = 0, - patch_size: int = 256, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for NAIP and Chesapeake based DataLoaders. - - Args: - naip_root_dir: directory containing NAIP data - chesapeake_root_dir: directory containing Chesapeake data - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - patch_size: size of patches to sample - """ - super().__init__() # type: ignore[no-untyped-call] - self.naip_root_dir = naip_root_dir - self.chesapeake_root_dir = chesapeake_root_dir - self.batch_size = batch_size - self.num_workers = num_workers - self.patch_size = patch_size - - def naip_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the NAIP Dataset. - - Args: - sample: NAIP image dictionary - - Returns: - preprocessed NAIP data - """ - sample["image"] = sample["image"] / 255.0 - sample["image"] = sample["image"].float() - return sample - - def chesapeake_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Chesapeake Dataset. - - Args: - sample: Chesapeake mask dictionary - - Returns: - preprocessed Chesapeake data - """ - sample["mask"] = sample["mask"].long()[0] - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - Chesapeake13(self.chesapeake_root_dir, download=False, checksum=False) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: state to set up - """ - # TODO: these transforms will be applied independently, this won't work if we - # add things like random horizontal flip - chesapeake = Chesapeake13( - self.chesapeake_root_dir, transforms=self.chesapeake_transform - ) - naip = NAIP( - self.naip_root_dir, - chesapeake.crs, - chesapeake.res, - transforms=self.naip_transform, - ) - self.dataset = chesapeake & naip - - # TODO: figure out better train/val/test split - roi = self.dataset.bounds - midx = roi.minx + (roi.maxx - roi.minx) / 2 - midy = roi.miny + (roi.maxy - roi.miny) / 2 - train_roi = BoundingBox(roi.minx, midx, roi.miny, roi.maxy, roi.mint, roi.maxt) - val_roi = BoundingBox(midx, roi.maxx, roi.miny, midy, roi.mint, roi.maxt) - test_roi = BoundingBox(roi.minx, roi.maxx, midy, roi.maxy, roi.mint, roi.maxt) - - self.train_sampler = RandomBatchGeoSampler( - naip, self.patch_size, self.batch_size, self.length, train_roi - ) - self.val_sampler = GridGeoSampler(naip, self.patch_size, self.stride, val_roi) - self.test_sampler = GridGeoSampler(naip, self.patch_size, self.stride, test_roi) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.dataset, - batch_sampler=self.train_sampler, - num_workers=self.num_workers, - collate_fn=stack_samples, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.dataset, - batch_size=self.batch_size, - sampler=self.val_sampler, - num_workers=self.num_workers, - collate_fn=stack_samples, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.dataset, - batch_size=self.batch_size, - sampler=self.test_sampler, - num_workers=self.num_workers, - collate_fn=stack_samples, - ) diff --git a/torchgeo/datasets/nasa_marine_debris.py b/torchgeo/datasets/nasa_marine_debris.py index bd239e65847..39f5ff62679 100644 --- a/torchgeo/datasets/nasa_marine_debris.py +++ b/torchgeo/datasets/nasa_marine_debris.py @@ -4,39 +4,17 @@ """NASA Marine Debris dataset.""" import os -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl import rasterio import torch from torch import Tensor -from torch.utils.data import DataLoader -from torchvision.transforms import Compose from torchvision.utils import draw_bounding_boxes from .geo import VisionDataset -from .utils import dataset_split, download_radiant_mlhub_dataset, extract_archive - -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" - - -def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]: - """Custom object detection collate fn to handle variable boxes. - - Args: - batch: list of sample dicts return by dataset - - Returns: - batch dict output - """ - output: Dict[str, Any] = {} - output["image"] = torch.stack([sample["image"] for sample in batch]) - output["boxes"] = [sample["boxes"] for sample in batch] - return output +from .utils import download_radiant_mlhub_dataset, extract_archive class NASAMarineDebris(VisionDataset): @@ -70,7 +48,7 @@ class NASAMarineDebris(VisionDataset): * `radiant-mlhub `_ to download the imagery and labels from the Radiant Earth MLHub - .. versionadded: 0.2 + .. versionadded:: 0.2 """ dataset_id = "nasa_marine_debris" @@ -279,109 +257,3 @@ def plot( plt.suptitle(suptitle) return fig - - -class NASAMarineDebrisDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the NASA Marine Debris dataset.""" - - def __init__( - self, - root_dir: str, - batch_size: int = 64, - num_workers: int = 0, - val_split_pct: float = 0.2, - test_split_pct: float = 0.2, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for NASA Marine Debris based DataLoaders. - - Args: - root_dir: The ``root`` argument to pass to the Dataset class - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - val_split_pct: What percentage of the dataset to use as a validation set - test_split_pct: What percentage of the dataset to use as a test set - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.batch_size = batch_size - self.num_workers = num_workers - self.val_split_pct = val_split_pct - self.test_split_pct = test_split_pct - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - NASAMarineDebris(self.root_dir, checksum=False) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - transforms = Compose([self.preprocess]) - - dataset = NASAMarineDebris(self.root_dir, transforms=transforms) - self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( - dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - collate_fn=collate_fn, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - collate_fn=collate_fn, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - collate_fn=collate_fn, - ) diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py index c2f807b490e..555037d971f 100644 --- a/torchgeo/datasets/oscd.py +++ b/torchgeo/datasets/oscd.py @@ -5,25 +5,23 @@ import glob import os -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Callable, Dict, List, Optional, Sequence, Union -import kornia.augmentation as K import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl import torch -from einops import repeat from matplotlib.figure import Figure from numpy import ndarray as Array from PIL import Image from torch import Tensor -from torch.utils.data import DataLoader -from torch.utils.data._utils.collate import default_collate -from torchvision.transforms import Compose, Normalize -from ..datasets.utils import dataset_split, draw_semantic_segmentation_masks from .geo import VisionDataset -from .utils import download_url, extract_archive, sort_sentinel2_bands +from .utils import ( + download_url, + draw_semantic_segmentation_masks, + extract_archive, + sort_sentinel2_bands, +) class OSCD(VisionDataset): @@ -62,8 +60,17 @@ class OSCD(VisionDataset): "https://partage.imt.fr/index.php/s/gpStKn4Mpgfnr63/download" ), } - - md5 = "7383412da7ece1dca1c12dc92ac77f09" + md5s = { + "Onera Satellite Change Detection dataset - Images.zip": ( + "c50d4a2941da64e03a47ac4dec63d915" + ), + "Onera Satellite Change Detection dataset - Train Labels.zip": ( + "4d2965af8170c705ebad3d6ee71b6990" + ), + "Onera Satellite Change Detection dataset - Test Labels.zip": ( + "8177d437793c522653c442aa4e66c617" + ), + } zipfile_glob = "*Onera*.zip" filename_glob = "*Onera*" @@ -256,7 +263,7 @@ def _download(self) -> None: self.urls[f_name], self.root, filename=f_name, - md5=self.md5 if self.checksum else None, + md5=self.md5s[f_name] if self.checksum else None, ) def _extract(self) -> None: @@ -317,202 +324,3 @@ def get_masked(img: Tensor) -> Array: # type: ignore[type-arg] plt.suptitle(suptitle) return fig - - -class OSCDDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the OSCD dataset. - - Uses the train/test splits from the dataset and further splits - the train split into train/val splits. - - .. versionadded: 0.2 - """ - - band_means = torch.tensor( # type: ignore[attr-defined] - [ - 1583.0741, - 1374.3202, - 1294.1616, - 1325.6158, - 1478.7408, - 1933.0822, - 2166.0608, - 2076.4868, - 2306.0652, - 690.9814, - 16.2360, - 2080.3347, - 1524.6930, - ] - ) - - band_stds = torch.tensor( # type: ignore[attr-defined] - [ - 52.1937, - 83.4168, - 105.6966, - 151.1401, - 147.4615, - 115.9289, - 123.1974, - 114.6483, - 141.4530, - 73.2758, - 4.8368, - 213.4821, - 179.4793, - ] - ) - - def __init__( - self, - root_dir: str, - bands: str = "all", - train_batch_size: int = 32, - num_workers: int = 0, - val_split_pct: float = 0.2, - patch_size: Tuple[int, int] = (64, 64), - num_patches_per_tile: int = 32, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for OSCD based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the OSCD Dataset classes - bands: "rgb" or "all" - train_batch_size: The batch size used in the train DataLoader - (val_batch_size == test_batch_size == 1) - num_workers: The number of workers to use in all created DataLoaders - val_split_pct: What percentage of the dataset to use as a validation set - patch_size: Size of random patch from image and mask (height, width) - num_patches_per_tile: number of random patches per sample - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.bands = bands - self.train_batch_size = train_batch_size - self.num_workers = num_workers - self.val_split_pct = val_split_pct - self.patch_size = patch_size - self.num_patches_per_tile = num_patches_per_tile - - if bands == "rgb": - self.band_means = self.band_means[[3, 2, 1], None, None] - self.band_stds = self.band_stds[[3, 2, 1], None, None] - else: - self.band_means = self.band_means[:, None, None] - self.band_stds = self.band_stds[:, None, None] - - self.norm = Normalize(self.band_means, self.band_stds) - self.rcrop = K.AugmentationSequential( - K.RandomCrop(patch_size), data_keys=["input", "mask"], same_on_batch=True - ) - self.padto = K.PadTo((1280, 1280)) - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset.""" - sample["image"] = sample["image"].float() - sample["mask"] = sample["mask"] - sample["image"] = self.norm(sample["image"]) - sample["image"] = torch.flatten( # type: ignore[attr-defined] - sample["image"], 0, 1 - ) - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - OSCD(self.root_dir, split="train", bands=self.bands, checksum=False) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - """ - - def n_random_crop(sample: Dict[str, Any]) -> Dict[str, Any]: - images, masks = [], [] - for i in range(self.num_patches_per_tile): - mask = repeat(sample["mask"], "h w -> t h w", t=2).float() - image, mask = self.rcrop(sample["image"], mask) - mask = mask.squeeze()[0] - images.append(image.squeeze()) - masks.append(mask.long()) - sample["image"] = torch.stack(images) - sample["mask"] = torch.stack(masks) - return sample - - def pad_to(sample: Dict[str, Any]) -> Dict[str, Any]: - sample["image"] = self.padto(sample["image"])[0] - sample["mask"] = self.padto(sample["mask"].float()).long()[0, 0] - return sample - - train_transforms = Compose([self.preprocess, n_random_crop]) - # for testing and validation we pad all inputs to a fixed size to avoid issues - # with the upsampling paths in encoder-decoder architectures - test_transforms = Compose([self.preprocess, pad_to]) - - train_dataset = OSCD( - self.root_dir, split="train", bands=self.bands, transforms=train_transforms - ) - if self.val_split_pct > 0.0: - val_dataset = OSCD( - self.root_dir, - split="train", - bands=self.bands, - transforms=test_transforms, - ) - self.train_dataset, self.val_dataset, _ = dataset_split( - train_dataset, val_pct=self.val_split_pct, test_pct=0.0 - ) - self.val_dataset.dataset = val_dataset - else: - self.train_dataset = train_dataset # type: ignore[assignment] - self.val_dataset = None # type: ignore[assignment] - - self.test_dataset = OSCD( - self.root_dir, split="test", bands=self.bands, transforms=test_transforms - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training.""" - - def collate_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]: - r_batch: Dict[str, Any] = default_collate( # type: ignore[no-untyped-call] - batch - ) - r_batch["image"] = torch.flatten( # type: ignore[attr-defined] - r_batch["image"], 0, 1 - ) - r_batch["mask"] = torch.flatten( # type: ignore[attr-defined] - r_batch["mask"], 0, 1 - ) - return r_batch - - return DataLoader( - self.train_dataset, - batch_size=self.train_batch_size, - num_workers=self.num_workers, - collate_fn=collate_wrapper, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation.""" - if self.val_split_pct == 0.0: - return self.train_dataloader() - else: - return DataLoader( - self.val_dataset, - batch_size=1, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing.""" - return DataLoader( - self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False - ) diff --git a/torchgeo/datasets/potsdam.py b/torchgeo/datasets/potsdam.py index 40149d0429f..a54e4b18f75 100644 --- a/torchgeo/datasets/potsdam.py +++ b/torchgeo/datasets/potsdam.py @@ -4,22 +4,23 @@ """Potsdam dataset.""" import os -from typing import Any, Callable, Dict, Optional +from typing import Callable, Dict, Optional import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl import rasterio import torch from matplotlib.figure import Figure from PIL import Image from torch import Tensor -from torch.utils.data import DataLoader -from torchvision.transforms import Compose -from ..datasets.utils import dataset_split, draw_semantic_segmentation_masks from .geo import VisionDataset -from .utils import check_integrity, extract_archive, rgb_to_mask +from .utils import ( + check_integrity, + draw_semantic_segmentation_masks, + extract_archive, + rgb_to_mask, +) class Potsdam2D(VisionDataset): @@ -293,111 +294,3 @@ def plot( plt.suptitle(suptitle) return fig - - -class Potsdam2DDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the Potsdam2D dataset. - - Uses the train/test splits from the dataset. - - .. versionadded: 0.2 - """ - - def __init__( - self, - root_dir: str, - batch_size: int = 64, - num_workers: int = 0, - val_split_pct: float = 0.2, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for Potsdam2D based DataLoaders. - - Args: - root_dir: The ``root`` argument to pass to the Potsdam2D Dataset classes - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - val_split_pct: What percentage of the dataset to use as a validation set - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.batch_size = batch_size - self.num_workers = num_workers - self.val_split_pct = val_split_pct - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - return sample - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - transforms = Compose([self.preprocess]) - - dataset = Potsdam2D(self.root_dir, "train", transforms=transforms) - - if self.val_split_pct > 0.0: - self.train_dataset, self.val_dataset, _ = dataset_split( - dataset, val_pct=self.val_split_pct, test_pct=0.0 - ) - else: - self.train_dataset = dataset # type: ignore[assignment] - self.val_dataset = None # type: ignore[assignment] - - self.test_dataset = Potsdam2D(self.root_dir, "test", transforms=transforms) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - if self.val_split_pct == 0.0: - return self.train_dataloader() - else: - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) diff --git a/torchgeo/datasets/resisc45.py b/torchgeo/datasets/resisc45.py index 4b5c9560a0b..13117d54645 100644 --- a/torchgeo/datasets/resisc45.py +++ b/torchgeo/datasets/resisc45.py @@ -4,23 +4,15 @@ """RESISC45 dataset.""" import os -from typing import Any, Callable, Dict, Optional, cast +from typing import Callable, Dict, Optional, cast import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl -import torch from torch import Tensor -from torch.utils.data import DataLoader -from torchvision.transforms import Compose, Normalize from .geo import VisionClassificationDataset from .utils import download_url, extract_archive -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" - class RESISC45(VisionClassificationDataset): """RESISC45 dataset. @@ -288,109 +280,3 @@ def plot( if suptitle is not None: plt.suptitle(suptitle) return fig - - -class RESISC45DataModule(pl.LightningDataModule): - """LightningDataModule implementation for the RESISC45 dataset. - - Uses the train/val/test splits from the dataset. - """ - - band_means = torch.tensor( # type: ignore[attr-defined] - [0.36801773, 0.38097873, 0.343583] - ) - - band_stds = torch.tensor( # type: ignore[attr-defined] - [0.14540215, 0.13558227, 0.13203649] - ) - - def __init__( - self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any - ) -> None: - """Initialize a LightningDataModule for RESISC45 based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the RESISC45 Dataset classes - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.batch_size = batch_size - self.num_workers = num_workers - - self.norm = Normalize(self.band_means, self.band_stds) - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - sample["image"] = self.norm(sample["image"]) - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - RESISC45(self.root_dir, checksum=False) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - transforms = Compose([self.preprocess]) - - self.train_dataset = RESISC45(self.root_dir, "train", transforms=transforms) - self.val_dataset = RESISC45(self.root_dir, "val", transforms=transforms) - self.test_dataset = RESISC45(self.root_dir, "test", transforms=transforms) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) diff --git a/torchgeo/datasets/sen12ms.py b/torchgeo/datasets/sen12ms.py index f1cf8b2ad2c..87eedab72a3 100644 --- a/torchgeo/datasets/sen12ms.py +++ b/torchgeo/datasets/sen12ms.py @@ -4,23 +4,17 @@ """SEN12MS dataset.""" import os -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Dict, Optional, Sequence +import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl import rasterio import torch -from sklearn.model_selection import GroupShuffleSplit from torch import Tensor -from torch.utils.data import DataLoader, Subset from .geo import VisionDataset from .utils import check_integrity -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" - class SEN12MS(VisionDataset): """SEN12MS dataset. @@ -69,13 +63,72 @@ class SEN12MS(VisionDataset): This download will likely take several hours. """ # noqa: E501 - BAND_SETS: Dict[str, List[int]] = { - "all": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], - "s1": [0, 1], - "s2-all": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], - "s2-reduced": [3, 4, 5, 9, 12, 13], + # BAND_SETS: Dict[str, List[int]] = { + # "all": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], + # "s1": [0, 1], + # "s2-all": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], + # "s2-reduced": [3, 4, 5, 9, 12, 13], + # } + + BAND_SETS: Dict[str, Sequence[str]] = { + "all": tuple( + [ + "VV", + "VH", + "B01", + "B02", + "B03", + "B04", + "B05", + "B06", + "B07", + "B08", + "B8A", + "B09", + "B10", + "B11", + "B12", + ] + ), + "s1": ("VV", "VH"), + "s2-all": ( + "B01", + "B02", + "B03", + "B04", + "B05", + "B06", + "B07", + "B08", + "B8A", + "B09", + "B10", + "B11", + "B12", + ), + "s2-reduced": ("B02", "B03", "B04", "B08", "B10", "B11"), } + band_names = ( + "VV", + "VH", + "B01", + "B02", + "B03", + "B04", + "B05", + "B06", + "B07", + "B08", + "B8A", + "B09", + "B10", + "B11", + "B12", + ) + + RGB_BANDS = ["B04", "B03", "B02"] + filenames = [ "ROIs1158_spring_lc.tar.gz", "ROIs1158_spring_s1.tar.gz", @@ -121,7 +174,7 @@ def __init__( self, root: str = "data", split: str = "train", - bands: List[int] = BAND_SETS["all"], + bands: Sequence[str] = BAND_SETS["all"], transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, checksum: bool = False, ) -> None: @@ -147,9 +200,14 @@ def __init__( """ assert split in ["train", "test"] + self._validate_bands(bands) + self.band_indices = torch.tensor( # type: ignore[attr-defined] + [self.band_names.index(b) for b in bands] + ).long() + self.bands = bands + self.root = root self.split = split - self.bands = torch.tensor(bands).long() # type: ignore[attr-defined] self.transforms = transforms self.checksum = checksum @@ -180,7 +238,7 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: image = torch.cat(tensors=[s1, s2], dim=0) # type: ignore[attr-defined] image = torch.index_select( # type: ignore[attr-defined] - image, dim=0, index=self.bands + image, dim=0, index=self.band_indices ) sample: Dict[str, Tensor] = {"image": image, "mask": lc} @@ -223,6 +281,21 @@ def _load_raster(self, filename: str, source: str) -> Tensor: tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined] return tensor + def _validate_bands(self, bands: Sequence[str]) -> None: + """Validate list of bands. + + Args: + bands: user-provided tuple of bands to load + + Raises: + AssertionError: if ``bands`` is not a tuple + ValueError: if an invalid band name is provided + """ + assert isinstance(bands, tuple), "The list of bands must be a tuple" + for band in bands: + if band not in self.band_names: + raise ValueError(f"'{band}' is an invalid band name.") + def _check_integrity_light(self) -> bool: """Checks the integrity of the dataset structure. @@ -247,187 +320,57 @@ def _check_integrity(self) -> bool: return False return True - -class SEN12MSDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the SEN12MS dataset. - - Implements 80/20 geographic train/val splits and uses the test split from the - classification dataset definitions. See :func:`setup` for more details. - - Uses the Simplified IGBP scheme defined in the 2020 Data Fusion Competition. See - https://arxiv.org/abs/2002.08254. - """ - - #: Mapping from the IGBP class definitions to the DFC2020, taken from the dataloader - #: here https://github.com/lukasliebel/dfc2020_baseline. - DFC2020_CLASS_MAPPING = torch.tensor( # type: ignore[attr-defined] - [ - 0, # maps 0s to 0 - 1, # maps 1s to 1 - 1, # maps 2s to 1 - 1, # ... - 1, - 1, - 2, - 2, - 3, - 3, - 4, - 5, - 6, - 7, - 6, - 8, - 9, - 10, - ] - ) - - def __init__( + def plot( self, - root_dir: str, - seed: int, - band_set: str = "all", - batch_size: int = 64, - num_workers: int = 0, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for SEN12MS based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the SEN12MS Dataset classes - seed: The seed value to use when doing the sklearn based ShuffleSplit - band_set: The subset of S1/S2 bands to use. Options are: "all", - "s1", "s2-all", and "s2-reduced" where the "s2-reduced" set includes: - B2, B3, B4, B8, B11, and B12. - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - """ - super().__init__() # type: ignore[no-untyped-call] - assert band_set in SEN12MS.BAND_SETS.keys() - - self.root_dir = root_dir - self.seed = seed - self.band_set = band_set - self.band_indices = SEN12MS.BAND_SETS[band_set] - self.batch_size = batch_size - self.num_workers = num_workers - - def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image and mask - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - - if self.band_set == "all": - sample["image"][:2] = sample["image"][:2].clamp(-25, 0) / -25 - sample["image"][2:] = sample["image"][2:].clamp(0, 10000) / 10000 - elif self.band_set == "s1": - sample["image"][:2] = sample["image"][:2].clamp(-25, 0) / -25 - else: - sample["image"][:] = sample["image"][:].clamp(0, 10000) / 10000 - - sample["mask"] = sample["mask"][0, :, :].long() - sample["mask"] = torch.take( # type: ignore[attr-defined] - self.DFC2020_CLASS_MAPPING, sample["mask"] - ) - - return sample - - def setup(self, stage: Optional[str] = None) -> None: - """Create the train/val/test splits based on the original Dataset objects. - - The splits should be done here vs. in :func:`__init__` per the docs: - https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup. - - We split samples between train and val geographically with proportions of 80/20. - This mimics the geographic test set split. + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. Args: - stage: stage to set up - """ - season_to_int = {"winter": 0, "spring": 1000, "summer": 2000, "fall": 3000} - - self.all_train_dataset = SEN12MS( - self.root_dir, - split="train", - bands=self.band_indices, - transforms=self.custom_transform, - checksum=False, - ) - - self.all_test_dataset = SEN12MS( - self.root_dir, - split="test", - bands=self.band_indices, - transforms=self.custom_transform, - checksum=False, - ) - - # A patch is a filename like: "ROIs{num}_{season}_s2_{scene_id}_p{patch_id}.tif" - # This patch will belong to the scene that is uniquelly identified by its - # (season, scene_id) tuple. Because the largest scene_id is 149, we can simply - # give each season a large number and representing a `unique_scene_id` as - # `season_id + scene_id`. - scenes = [] - for scene_fn in self.all_train_dataset.ids: - parts = scene_fn.split("_") - season_id = season_to_int[parts[1]] - scene_id = int(parts[3]) - scenes.append(season_id + scene_id) - - train_indices, val_indices = next( - GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=self.seed).split( - scenes, groups=scenes - ) - ) + sample: a sample return by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional suptitle to use for figure - self.train_dataset = Subset(self.all_train_dataset, train_indices) - self.val_dataset = Subset(self.all_train_dataset, val_indices) - self.test_dataset = Subset( - self.all_test_dataset, range(len(self.all_test_dataset)) - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. + Returns; + a matplotlib Figure with the rendered sample - Returns: - validation data loader + .. versionadded:: 0.2 """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) + rgb_indices = [] + for band in self.RGB_BANDS: + if band in self.bands: + rgb_indices.append(self.bands.index(band)) + else: + raise ValueError("Dataset doesn't contain some of the RGB bands") + + image, mask = sample["image"][rgb_indices, ...], sample["mask"][0] + ncols = 2 + + showing_predictions = "prediction" in sample + if showing_predictions: + prediction = sample["prediction"][0] + ncols += 1 + + fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5)) + + axs[0].imshow(image.permute(1, 2, 0)) + axs[0].axis("off") + axs[1].imshow(mask) + axs[1].axis("off") + + if showing_predictions: + axs[2].imshow(prediction) + axs[2].axis("off") + + if show_titles: + axs[0].set_title("Image") + axs[1].set_title("Mask") + if showing_predictions: + axs[2].set_title("Prediction") + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig diff --git a/torchgeo/datasets/so2sat.py b/torchgeo/datasets/so2sat.py index 606a73e62c0..aaee0ecd2d3 100644 --- a/torchgeo/datasets/so2sat.py +++ b/torchgeo/datasets/so2sat.py @@ -4,23 +4,16 @@ """So2Sat dataset.""" import os -from typing import Any, Callable, Dict, Optional, cast +from typing import Callable, Dict, Optional, cast import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl import torch from torch import Tensor -from torch.utils.data import DataLoader -from torchvision.transforms import Compose from .geo import VisionDataset from .utils import check_integrity, percentile_normalization -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" - class So2Sat(VisionDataset): """So2Sat dataset. @@ -250,211 +243,3 @@ def plot( if suptitle is not None: plt.suptitle(suptitle) return fig - - -class So2SatDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the So2Sat dataset. - - Uses the train/val/test splits from the dataset. - """ - - band_means = torch.tensor( # type: ignore[attr-defined] - [ - -3.591224256609313e-05, - -7.658561276843396e-06, - 5.9373857475971184e-05, - 2.5166231537121083e-05, - 0.04420110659759328, - 0.25761027084996196, - 0.0007556743372573258, - 0.0013503466830024448, - 0.12375696117681859, - 0.1092774636368323, - 0.1010855203267882, - 0.1142398616114001, - 0.1592656692023089, - 0.18147236008771792, - 0.1745740312291377, - 0.19501607349635292, - 0.15428468872076637, - 0.10905050699570007, - ] - ).reshape(18, 1, 1) - - band_stds = torch.tensor( # type: ignore[attr-defined] - [ - 0.17555201137417686, - 0.17556463274968204, - 0.45998793417834255, - 0.455988755730148, - 2.8559909213125763, - 8.324800606439833, - 2.4498757382563103, - 1.4647352984509094, - 0.03958795985905458, - 0.047778262752410296, - 0.06636616706371974, - 0.06358874912497474, - 0.07744387147984592, - 0.09101635085921553, - 0.09218466562387101, - 0.10164581233948201, - 0.09991773043519253, - 0.08780632509122865, - ] - ).reshape(18, 1, 1) - - # this reorders the bands to put S2 RGB first, then remainder of S2, then S1 - reindex_to_rgb_first = [ - 10, - 9, - 8, - 11, - 12, - 13, - 14, - 15, - 16, - 17, - # 0, - # 1, - # 2, - # 3, - # 4, - # 5, - # 6, - # 7, - ] - - def __init__( - self, - root_dir: str, - batch_size: int = 64, - num_workers: int = 0, - bands: str = "rgb", - unsupervised_mode: bool = False, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for So2Sat based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the So2Sat Dataset classes - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - bands: Either "rgb" or "s2" - unsupervised_mode: Makes the train dataloader return imagery from the train, - val, and test sets - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.batch_size = batch_size - self.num_workers = num_workers - self.bands = bands - self.unsupervised_mode = unsupervised_mode - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image - - Returns: - preprocessed sample - """ - # sample["image"] = (sample["image"] - self.band_means) / self.band_stds - sample["image"] = sample["image"].float() - sample["image"] = sample["image"][self.reindex_to_rgb_first, :, :] - - if self.bands == "rgb": - sample["image"] = sample["image"][:3, :, :] - - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - So2Sat(self.root_dir, checksum=False) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - train_transforms = Compose([self.preprocess]) - val_test_transforms = self.preprocess - - if not self.unsupervised_mode: - - self.train_dataset = So2Sat( - self.root_dir, split="train", transforms=train_transforms - ) - - self.val_dataset = So2Sat( - self.root_dir, split="validation", transforms=val_test_transforms - ) - - self.test_dataset = So2Sat( - self.root_dir, split="test", transforms=val_test_transforms - ) - - else: - - temp_train = So2Sat( - self.root_dir, split="train", transforms=train_transforms - ) - - self.val_dataset = So2Sat( - self.root_dir, split="validation", transforms=train_transforms - ) - - self.test_dataset = So2Sat( - self.root_dir, split="test", transforms=train_transforms - ) - - self.train_dataset = cast( - So2Sat, temp_train + self.val_dataset + self.test_dataset - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) diff --git a/torchgeo/datasets/spacenet.py b/torchgeo/datasets/spacenet.py index 41cb1f2ff6c..38ac5c6a150 100644 --- a/torchgeo/datasets/spacenet.py +++ b/torchgeo/datasets/spacenet.py @@ -24,8 +24,8 @@ from rasterio.transform import Affine from torch import Tensor -from torchgeo.datasets.geo import VisionDataset -from torchgeo.datasets.utils import ( +from .geo import VisionDataset +from .utils import ( check_integrity, download_radiant_mlhub_collection, extract_archive, diff --git a/torchgeo/datasets/ucmerced.py b/torchgeo/datasets/ucmerced.py index 431b526b756..21b09e32a1d 100644 --- a/torchgeo/datasets/ucmerced.py +++ b/torchgeo/datasets/ucmerced.py @@ -3,24 +3,15 @@ """UC Merced dataset.""" import os -from typing import Any, Callable, Dict, Optional, cast +from typing import Callable, Dict, Optional, cast import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl -import torch -import torchvision from torch import Tensor -from torch.utils.data import DataLoader -from torchvision.transforms import Compose, Normalize from .geo import VisionClassificationDataset from .utils import check_integrity, download_url, extract_archive -# https://github.com/pytorch/pytorch/issues/60979 -# https://github.com/pytorch/pytorch/pull/61045 -DataLoader.__module__ = "torch.utils.data" - class UCMerced(VisionClassificationDataset): """UC Merced dataset. @@ -251,110 +242,3 @@ def plot( if suptitle is not None: plt.suptitle(suptitle) return fig - - -class UCMercedDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the UC Merced dataset. - - Uses random train/val/test splits. - """ - - band_means = torch.tensor([0, 0, 0]) # type: ignore[attr-defined] - - band_stds = torch.tensor([1, 1, 1]) # type: ignore[attr-defined] - - def __init__( - self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any - ) -> None: - """Initialize a LightningDataModule for UCMerced based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the UCMerced Dataset classes - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.batch_size = batch_size - self.num_workers = num_workers - - self.norm = Normalize(self.band_means, self.band_stds) - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - c, h, w = sample["image"].shape - if h != 256 or w != 256: - sample["image"] = torchvision.transforms.functional.resize( - sample["image"], size=(256, 256) - ) - sample["image"] = self.norm(sample["image"]) - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - UCMerced(self.root_dir, download=False, checksum=False) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - transforms = Compose([self.preprocess]) - - self.train_dataset = UCMerced(self.root_dir, "train", transforms=transforms) - self.val_dataset = UCMerced(self.root_dir, "val", transforms=transforms) - self.test_dataset = UCMerced(self.root_dir, "test", transforms=transforms) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index fceb6201a38..472977af848 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -32,7 +32,6 @@ import rasterio import torch from torch import Tensor -from torch.utils.data import Dataset, Subset, random_split from torchvision.datasets.utils import check_integrity, download_url from torchvision.utils import draw_segmentation_masks @@ -48,7 +47,6 @@ "concat_samples", "merge_samples", "rasterio_loader", - "dataset_split", "sort_sentinel2_bands", "draw_semantic_segmentation_masks", "rgb_to_mask", @@ -223,6 +221,8 @@ def __post_init__(self) -> None: Raises: ValueError: if bounding box is invalid (minx > maxx, miny > maxy, or mint > maxt) + + .. versionadded:: 0.2 """ if self.minx > self.maxx: raise ValueError( @@ -276,6 +276,8 @@ def __contains__(self, other: "BoundingBox") -> bool: Returns: True if other is within this bounding box, else False + + .. versionadded:: 0.2 """ return ( (self.minx <= other.minx <= self.maxx) @@ -294,6 +296,8 @@ def __or__(self, other: "BoundingBox") -> "BoundingBox": Returns: the minimum bounding box that contains both self and other + + .. versionadded:: 0.2 """ return BoundingBox( min(self.minx, other.minx), @@ -315,6 +319,8 @@ def __and__(self, other: "BoundingBox") -> "BoundingBox": Raises: ValueError: if self and other do not intersect + + .. versionadded:: 0.2 """ try: return BoundingBox( @@ -431,6 +437,8 @@ def _list_dict_to_dict_list(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, List Returns: a dictionary of lists + + .. versionadded:: 0.2 """ collated = collections.defaultdict(list) for sample in samples: @@ -450,6 +458,8 @@ def stack_samples(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, Any]: Returns: a single sample + + .. versionadded:: 0.2 """ collated: Dict[Any, Any] = _list_dict_to_dict_list(samples) for key, value in collated.items(): @@ -468,6 +478,8 @@ def concat_samples(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, Any]: Returns: a single sample + + .. versionadded:: 0.2 """ collated: Dict[Any, Any] = _list_dict_to_dict_list(samples) for key, value in collated.items(): @@ -488,6 +500,8 @@ def merge_samples(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, Any]: Returns: a single sample + + .. versionadded:: 0.2 """ collated: Dict[Any, Any] = {} for sample in samples: @@ -519,31 +533,6 @@ def rasterio_loader(path: str) -> np.ndarray: # type: ignore[type-arg] return array -def dataset_split( - dataset: Dataset[Any], val_pct: float, test_pct: Optional[float] = None -) -> List[Subset[Any]]: - """Split a torch Dataset into train/val/test sets. - - If ``test_pct`` is not set then only train and validation splits are returned. - - Args: - dataset: dataset to be split into train/val or train/val/test subsets - val_pct: percentage of samples to be in validation set - test_pct: (Optional) percentage of samples to be in test set - Returns: - a list of the subset datasets. Either [train, val] or [train, val, test] - """ - if test_pct is None: - val_length = int(len(dataset) * val_pct) - train_length = len(dataset) - val_length - return random_split(dataset, [train_length, val_length]) - else: - val_length = int(len(dataset) * val_pct) - test_length = int(len(dataset) * test_pct) - train_length = len(dataset) - (val_length + test_length) - return random_split(dataset, [train_length, val_length, test_length]) - - def sort_sentinel2_bands(x: str) -> str: """Sort Sentinel-2 band files in the correct order.""" x = os.path.basename(x).split("_")[-1] diff --git a/torchgeo/datasets/vaihingen.py b/torchgeo/datasets/vaihingen.py index f95e8e72d48..e46b9af9ccd 100644 --- a/torchgeo/datasets/vaihingen.py +++ b/torchgeo/datasets/vaihingen.py @@ -4,21 +4,22 @@ """Vaihingen dataset.""" import os -from typing import Any, Callable, Dict, Optional +from typing import Callable, Dict, Optional import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl import torch from matplotlib.figure import Figure from PIL import Image from torch import Tensor -from torch.utils.data import DataLoader -from torchvision.transforms import Compose -from ..datasets.utils import dataset_split, draw_semantic_segmentation_masks from .geo import VisionDataset -from .utils import check_integrity, extract_archive, rgb_to_mask +from .utils import ( + check_integrity, + draw_semantic_segmentation_masks, + extract_archive, + rgb_to_mask, +) class Vaihingen2D(VisionDataset): @@ -50,7 +51,7 @@ class Vaihingen2D(VisionDataset): * https://doi.org/10.5194/isprsannals-I-3-293-2012 - .. versionadded: 0.2 + .. versionadded:: 0.2 """ # noqa: E501 filenames = [ @@ -293,111 +294,3 @@ def plot( plt.suptitle(suptitle) return fig - - -class Vaihingen2DDataModule(pl.LightningDataModule): - """LightningDataModule implementation for the Vaihingen2D dataset. - - Uses the train/test splits from the dataset. - - .. versionadded: 0.2 - """ - - def __init__( - self, - root_dir: str, - batch_size: int = 64, - num_workers: int = 0, - val_split_pct: float = 0.2, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for Vaihingen2D based DataLoaders. - - Args: - root_dir: The ``root`` argument to pass to the Vaihingen Dataset classes - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - val_split_pct: What percentage of the dataset to use as a validation set - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.batch_size = batch_size - self.num_workers = num_workers - self.val_split_pct = val_split_pct - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - return sample - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - transforms = Compose([self.preprocess]) - - dataset = Vaihingen2D(self.root_dir, "train", transforms=transforms) - - if self.val_split_pct > 0.0: - self.train_dataset, self.val_dataset, _ = dataset_split( - dataset, val_pct=self.val_split_pct, test_pct=0.0 - ) - else: - self.train_dataset = dataset # type: ignore[assignment] - self.val_dataset = None # type: ignore[assignment] - - self.test_dataset = Vaihingen2D(self.root_dir, "test", transforms=transforms) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - if self.val_split_pct == 0.0: - return self.train_dataloader() - else: - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index c4e7774e04d..69d2421ee20 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -5,20 +5,16 @@ import glob import os -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional import matplotlib.pyplot as plt import numpy as np -import pytorch_lightning as pl import torch from PIL import Image from torch import Tensor -from torch.utils.data import DataLoader -from torchvision.transforms import Compose -from ..datasets.utils import dataset_split, draw_semantic_segmentation_masks from .geo import VisionDataset -from .utils import check_integrity, extract_archive +from .utils import check_integrity, draw_semantic_segmentation_masks, extract_archive class XView2(VisionDataset): @@ -48,7 +44,7 @@ class XView2(VisionDataset): * https://arxiv.org/abs/1911.09296 - .. versionadded: 0.2 + .. versionadded:: 0.2 """ metadata = { @@ -282,111 +278,3 @@ def plot( plt.suptitle(suptitle) return fig - - -class XView2DataModule(pl.LightningDataModule): - """LightningDataModule implementation for the xView2 dataset. - - Uses the train/val/test splits from the dataset. - - .. versionadded: 0.2 - """ - - def __init__( - self, - root_dir: str, - batch_size: int = 64, - num_workers: int = 0, - val_split_pct: float = 0.2, - **kwargs: Any, - ) -> None: - """Initialize a LightningDataModule for xView2 based DataLoaders. - - Args: - root_dir: The ``root`` arugment to pass to the xView2 Dataset classes - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - val_split_pct: What percentage of the dataset to use as a validation set - """ - super().__init__() # type: ignore[no-untyped-call] - self.root_dir = root_dir - self.batch_size = batch_size - self.num_workers = num_workers - self.val_split_pct = val_split_pct - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - return sample - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - transforms = Compose([self.preprocess]) - - dataset = XView2(self.root_dir, "train", transforms=transforms) - - if self.val_split_pct > 0.0: - self.train_dataset, self.val_dataset, _ = dataset_split( - dataset, val_pct=self.val_split_pct, test_pct=0.0 - ) - else: - self.train_dataset = dataset # type: ignore[assignment] - self.val_dataset = None # type: ignore[assignment] - - self.test_dataset = XView2(self.root_dir, "test", transforms=transforms) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - if self.val_split_pct == 0.0: - return self.train_dataloader() - else: - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) diff --git a/torchgeo/losses/__init__.py b/torchgeo/losses/__init__.py new file mode 100644 index 00000000000..d2967cd97fd --- /dev/null +++ b/torchgeo/losses/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""TorchGeo losses.""" + +from .qr import QRLoss, RQLoss + +__all__ = ("QRLoss", "RQLoss") + +# https://stackoverflow.com/questions/40018681 +for module in __all__: + globals()[module].__module__ = "torchgeo.losses" diff --git a/torchgeo/losses/qr.py b/torchgeo/losses/qr.py new file mode 100644 index 00000000000..e24d0d773cf --- /dev/null +++ b/torchgeo/losses/qr.py @@ -0,0 +1,78 @@ +"""Loss functions for learing on the prior.""" + +from typing import cast + +import torch +import torch.nn.functional as F +from torch.nn.modules import Module + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +Module.__module__ = "torch.nn" + + +class QRLoss(Module): + """The QR (forward) loss between class probabilities and predictions. + + This loss is defined in `'Resolving label uncertainty with implicit generative + models' `_. + + .. versionadded:: 0.2 + """ + + def forward(self, probs: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Computes the QR (forwards) loss on prior. + + Args: + probs: probabilities of predictions, expected shape B x C x H x W. + target: prior probabilities, expected shape B x C x H x W. + + Returns: + qr loss + """ + q = probs + q_bar = q.mean(dim=(0, 2, 3)) + qbar_log_S = (q_bar * torch.log(q_bar)).sum() # type: ignore[attr-defined] + + q_log_p = torch.einsum( # type: ignore[attr-defined] + "bcxy,bcxy->bxy", q, torch.log(target) # type: ignore[attr-defined] + ).mean() + + loss = qbar_log_S - q_log_p + return cast(torch.Tensor, loss) + + +class RQLoss(Module): + """The RQ (backwards) loss between class probabilities and predictions. + + This loss is defined in `'Resolving label uncertainty with implicit generative + models' `_. + + .. versionadded:: 0.2 + """ + + def forward(self, probs: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Computes the RQ (backwards) loss on prior. + + Args: + probs: probabilities of predictions, expected shape B x C x H x W + target: prior probabilities, expected shape B x C x H x W + + Returns: + qr loss + """ + q = probs + + # manually normalize due to https://github.com/pytorch/pytorch/issues/70100 + z = q / q.norm( # type: ignore[no-untyped-call] + p=1, dim=(0, 2, 3), keepdim=True + ).clamp_min(1e-12).expand_as(q) + r = F.normalize(z * target, p=1, dim=1) + + loss = torch.einsum( # type: ignore[attr-defined] + "bcxy,bcxy->bxy", + r, + torch.log(r) - torch.log(q), # type: ignore[attr-defined] + ).mean() + + return cast(torch.Tensor, loss) diff --git a/torchgeo/models/fcn.py b/torchgeo/models/fcn.py index d02db303976..6cde6da2c73 100644 --- a/torchgeo/models/fcn.py +++ b/torchgeo/models/fcn.py @@ -23,7 +23,7 @@ def __init__(self, in_channels: int, classes: int, num_filters: int = 64) -> Non classes: Number of filters in the final layer num_filters: Number of filters in each convolutional layer """ - super(FCN, self).__init__() + super().__init__() conv1 = nn.modules.Conv2d( in_channels, num_filters, kernel_size=3, stride=1, padding=1 diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py index 0dcd1a85aef..3dc5dcc4b8e 100644 --- a/torchgeo/samplers/batch.py +++ b/torchgeo/samplers/batch.py @@ -10,9 +10,7 @@ from rtree.index import Index, Property from torch.utils.data import Sampler -from torchgeo.datasets.geo import GeoDataset -from torchgeo.datasets.utils import BoundingBox - +from ..datasets import BoundingBox, GeoDataset from .utils import _to_tuple, get_random_bounding_box # https://github.com/pytorch/pytorch/issues/60979 diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 1804d9a2d84..d507f698e3b 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -10,9 +10,7 @@ from rtree.index import Index, Property from torch.utils.data import Sampler -from torchgeo.datasets.geo import GeoDataset -from torchgeo.datasets.utils import BoundingBox - +from ..datasets import BoundingBox, GeoDataset from .utils import _to_tuple, get_random_bounding_box # https://github.com/pytorch/pytorch/issues/60979 diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index b8aecd85a11..265859eeb06 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -6,7 +6,7 @@ import random from typing import Tuple, Union -from torchgeo.datasets.utils import BoundingBox +from ..datasets import BoundingBox def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]: diff --git a/torchgeo/trainers/byol.py b/torchgeo/trainers/byol.py index fb3565a7887..44e36a150b2 100644 --- a/torchgeo/trainers/byol.py +++ b/torchgeo/trainers/byol.py @@ -247,7 +247,7 @@ def __init__( model: Module, image_size: Tuple[int, int] = (256, 256), hidden_layer: Union[str, int] = -2, - input_channels: int = 4, + in_channels: int = 4, projection_size: int = 256, hidden_size: int = 4096, augment_fn: Optional[Module] = None, @@ -261,7 +261,7 @@ def __init__( image_size: the size of the training images hidden_layer: the hidden layer in ``model`` to attach the projection head to, can be the name of the layer or index of the layer - input_channels: number of input channels to the model + in_channels: number of input channels to the model projection_size: size of first layer of the projection MLP hidden_size: size of the hidden layer of the projection MLP augment_fn: an instance of a module that performs data augmentation @@ -277,7 +277,7 @@ def __init__( self.augment = augment_fn self.beta = beta - self.input_channels = input_channels + self.in_channels = in_channels self.encoder = EncoderWrapper( model, projection_size, hidden_size, layer=hidden_layer ) @@ -288,9 +288,7 @@ def __init__( # Perform a single forward pass to initialize the wrapper correctly self.encoder( - torch.zeros( # type: ignore[attr-defined] - 2, self.input_channels, *image_size - ) + torch.zeros(2, self.in_channels, *image_size) # type: ignore[attr-defined] ) def forward(self, x: Tensor) -> Tensor: @@ -315,21 +313,23 @@ class BYOLTask(LightningModule): def config_task(self) -> None: """Configures the task based on kwargs parameters passed to the constructor.""" - input_channels = self.hparams["input_channels"] + in_channels = self.hparams["in_channels"] pretrained = self.hparams["imagenet_pretraining"] encoder = None - if self.hparams["encoder"] == "resnet18": + if self.hparams["encoder_name"] == "resnet18": encoder = resnet18(pretrained=pretrained) - elif self.hparams["encoder"] == "resnet50": + elif self.hparams["encoder_name"] == "resnet50": encoder = resnet50(pretrained=pretrained) else: - raise ValueError(f"Encoder type '{self.hparams['encoder']}' is not valid.") + raise ValueError( + f"Encoder type '{self.hparams['encoder_name']}' is not valid." + ) layer = encoder.conv1 # Creating new Conv2d layer new_layer = Conv2d( - in_channels=input_channels, + in_channels=in_channels, out_channels=layer.out_channels, kernel_size=layer.kernel_size, stride=layer.stride, @@ -343,7 +343,7 @@ def config_task(self) -> None: ... # type: ignore[index] ] = Variable(layer.weight.clone(), requires_grad=True) # Copying the weights of the old layer to the extra channels - for i in range(input_channels - layer.in_channels): + for i in range(in_channels - layer.in_channels): channel = layer.in_channels + i new_layer.weight[:, channel : channel + 1, :, :].data[ ... # type: ignore[index] @@ -359,8 +359,8 @@ def __init__(self, **kwargs: Any) -> None: """Initialize a LightningModule for pre-training a model with BYOL. Keyword Args: - input_channels: number of channels on the input imagery - encoder: either "resnet18" or "resnet50" + in_channels: number of channels on the input imagery + encoder_name: either "resnet18" or "resnet50" imagenet_pretraining: bool indicating whether to use imagenet pretrained weights diff --git a/torchgeo/trainers/chesapeake.py b/torchgeo/trainers/chesapeake.py index 15a2e247dee..f1c78ef872a 100644 --- a/torchgeo/trainers/chesapeake.py +++ b/torchgeo/trainers/chesapeake.py @@ -35,7 +35,7 @@ class ChesapeakeCVPRSegmentationTask(SemanticSegmentationTask): """LightningModule for training models on the Chesapeake CVPR Land Cover dataset. - .. deprecated: 0.1 + .. deprecated:: 0.1 Use :class:`SemanticSegmentationTask` instead. """ diff --git a/torchgeo/trainers/landcoverai.py b/torchgeo/trainers/landcoverai.py index c95f0313523..7c9ac27856a 100644 --- a/torchgeo/trainers/landcoverai.py +++ b/torchgeo/trainers/landcoverai.py @@ -19,7 +19,7 @@ class LandCoverAISegmentationTask(SemanticSegmentationTask): """LightningModule for training models on the Landcover.AI Dataset. - .. deprecated: 0.1 + .. deprecated:: 0.1 Use :class:`SemanticSegmentationTask` instead. """ @@ -46,10 +46,10 @@ def training_step( # type: ignore[override] training loss """ x = batch["image"] - y = batch["mask"] + y = batch["mask"].float().unsqueeze(1) with torch.no_grad(): x, y = self.train_augmentations(x, y) - y = y.long().squeeze() + y = y.squeeze(1).long() y_hat = self.forward(x) y_hat_hard = y_hat.argmax(dim=1) @@ -76,7 +76,7 @@ def validation_step( # type: ignore[override] batch_idx: Index of current batch """ x = batch["image"] - y = batch["mask"].long().squeeze() + y = batch["mask"] y_hat = self.forward(x) y_hat_hard = y_hat.argmax(dim=1) @@ -120,7 +120,7 @@ def test_step( # type: ignore[override] batch_idx: Index of current batch """ x = batch["image"] - y = batch["mask"].long().squeeze() + y = batch["mask"] y_hat = self.forward(x) y_hat_hard = y_hat.argmax(dim=1) diff --git a/torchgeo/trainers/naipchesapeake.py b/torchgeo/trainers/naipchesapeake.py index b5480627772..4139b00ce67 100644 --- a/torchgeo/trainers/naipchesapeake.py +++ b/torchgeo/trainers/naipchesapeake.py @@ -16,7 +16,7 @@ class NAIPChesapeakeSegmentationTask(SemanticSegmentationTask): """LightningModule for training models on the NAIP and Chesapeake datasets. - .. deprecated: 0.1 + .. deprecated:: 0.1 Use :class:`SemanticSegmentationTask` instead. """ diff --git a/torchgeo/transforms/indices.py b/torchgeo/transforms/indices.py index 7f7212b6e9e..dae8f55b5c3 100644 --- a/torchgeo/transforms/indices.py +++ b/torchgeo/transforms/indices.py @@ -12,7 +12,7 @@ import torch from torch import Tensor -from torch.nn import Module # type: ignore[attr-defined] +from torch.nn.modules import Module # https://github.com/pytorch/pytorch/issues/60979 # https://github.com/pytorch/pytorch/pull/61045 @@ -90,7 +90,7 @@ def ndwi(green: Tensor, nir: Tensor) -> Tensor: return (green - nir) / ((green + nir) + _EPSILON) -class AppendNDBI(Module): # type: ignore[misc,name-defined] +class AppendNDBI(Module): """Normalized Difference Built-up Index (NDBI). If you use this dataset in your research, please cite the following paper: @@ -132,7 +132,7 @@ def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]: return sample -class AppendNBR(Module): # type: ignore[misc,name-defined] +class AppendNBR(Module): """Normalized Burn Ratio (NBR). .. versionadded:: 0.2.0 @@ -172,7 +172,7 @@ def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]: return sample -class AppendNDSI(Module): # type: ignore[misc,name-defined] +class AppendNDSI(Module): """Normalized Difference Snow Index (NDSI). If you use this dataset in your research, please cite the following paper: @@ -214,7 +214,7 @@ def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]: return sample -class AppendNDVI(Module): # type: ignore[misc,name-defined] +class AppendNDVI(Module): """Normalized Difference Vegetation Index (NDVI). If you use this dataset in your research, please cite the following paper: @@ -256,7 +256,7 @@ def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]: return sample -class AppendNDWI(Module): # type: ignore[misc,name-defined] +class AppendNDWI(Module): """Normalized Difference Water Index (NDWI). If you use this dataset in your research, please cite the following paper: diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index b1b8cda3af8..67be28c31bb 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -8,14 +8,14 @@ import kornia.augmentation as K import torch from torch import Tensor -from torch.nn import Module # type: ignore[attr-defined] +from torch.nn.modules import Module # https://github.com/pytorch/pytorch/issues/60979 # https://github.com/pytorch/pytorch/pull/61045 Module.__module__ = "torch.nn" -class AugmentationSequential(Module): # type: ignore[misc] +class AugmentationSequential(Module): """Wrapper around kornia AugmentationSequential to handle input dicts.""" def __init__(self, *args: Module, data_keys: List[str]) -> None: diff --git a/train.py b/train.py index 9c4928589f2..eac21a07d7f 100755 --- a/train.py +++ b/train.py @@ -6,14 +6,67 @@ """torchgeo model training script.""" import os -from typing import Any, Dict, cast +from typing import Any, Dict, Tuple, Type, cast import pytorch_lightning as pl from omegaconf import DictConfig, OmegaConf from pytorch_lightning import loggers as pl_loggers from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -from torchgeo import _TASK_TO_MODULES_MAPPING as TASK_TO_MODULES_MAPPING +from torchgeo.datamodules import ( + BigEarthNetDataModule, + ChesapeakeCVPRDataModule, + COWCCountingDataModule, + CycloneDataModule, + ETCI2021DataModule, + EuroSATDataModule, + LandCoverAIDataModule, + NAIPChesapeakeDataModule, + OSCDDataModule, + RESISC45DataModule, + SEN12MSDataModule, + So2SatDataModule, + UCMercedDataModule, +) +from torchgeo.trainers import ( + BYOLTask, + ClassificationTask, + MultiLabelClassificationTask, + RegressionTask, + SemanticSegmentationTask, +) +from torchgeo.trainers.chesapeake import ChesapeakeCVPRSegmentationTask +from torchgeo.trainers.landcoverai import LandCoverAISegmentationTask +from torchgeo.trainers.naipchesapeake import NAIPChesapeakeSegmentationTask +from torchgeo.trainers.resisc45 import RESISC45ClassificationTask + +TASK_TO_MODULES_MAPPING: Dict[ + str, Tuple[Type[pl.LightningModule], Type[pl.LightningDataModule]] +] = { + "bigearthnet_all": (MultiLabelClassificationTask, BigEarthNetDataModule), + "bigearthnet_s1": (MultiLabelClassificationTask, BigEarthNetDataModule), + "bigearthnet_s2": (MultiLabelClassificationTask, BigEarthNetDataModule), + "byol": (BYOLTask, ChesapeakeCVPRDataModule), + "chesapeake_cvpr_5": (ChesapeakeCVPRSegmentationTask, ChesapeakeCVPRDataModule), + "chesapeake_cvpr_7": (ChesapeakeCVPRSegmentationTask, ChesapeakeCVPRDataModule), + "chesapeake_cvpr_prior": (ChesapeakeCVPRSegmentationTask, ChesapeakeCVPRDataModule), + "cowc_counting": (RegressionTask, COWCCountingDataModule), + "cyclone": (RegressionTask, CycloneDataModule), + "eurosat": (ClassificationTask, EuroSATDataModule), + "etci2021": (SemanticSegmentationTask, ETCI2021DataModule), + "landcoverai": (LandCoverAISegmentationTask, LandCoverAIDataModule), + "naipchesapeake": (NAIPChesapeakeSegmentationTask, NAIPChesapeakeDataModule), + "oscd_all": (SemanticSegmentationTask, OSCDDataModule), + "oscd_rgb": (SemanticSegmentationTask, OSCDDataModule), + "resisc45": (RESISC45ClassificationTask, RESISC45DataModule), + "sen12ms_all": (SemanticSegmentationTask, SEN12MSDataModule), + "sen12ms_s1": (SemanticSegmentationTask, SEN12MSDataModule), + "sen12ms_s2_all": (SemanticSegmentationTask, SEN12MSDataModule), + "sen12ms_s2_reduced": (SemanticSegmentationTask, SEN12MSDataModule), + "so2sat_supervised": (ClassificationTask, So2SatDataModule), + "so2sat_unsupervised": (ClassificationTask, So2SatDataModule), + "ucmerced": (ClassificationTask, UCMercedDataModule), +} def set_up_omegaconf() -> DictConfig: