
Commit

Add predict utils
ashnair1 committed Jun 16, 2022
1 parent b412f52 commit 2e280a2
Showing 4 changed files with 304 additions and 5 deletions.
187 changes: 187 additions & 0 deletions predict.py
@@ -0,0 +1,187 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""torchgeo model inference script."""

import argparse
import os
from typing import Dict, Tuple, Type, cast

import pytorch_lightning as pl
import rasterio as rio
import torch
from kornia.contrib import CombineTensorPatches
from omegaconf import OmegaConf

from torchgeo.datamodules import (
    BigEarthNetDataModule,
    ChesapeakeCVPRDataModule,
    COWCCountingDataModule,
    CycloneDataModule,
    ETCI2021DataModule,
    EuroSATDataModule,
    InriaAerialImageLabelingDataModule,
    LandCoverAIDataModule,
    NAIPChesapeakeDataModule,
    OSCDDataModule,
    RESISC45DataModule,
    SEN12MSDataModule,
    So2SatDataModule,
    UCMercedDataModule,
)
from torchgeo.trainers import (
    BYOLTask,
    ClassificationTask,
    MultiLabelClassificationTask,
    RegressionTask,
    SemanticSegmentationTask,
)

TASK_TO_MODULES_MAPPING: Dict[
    str, Tuple[Type[pl.LightningModule], Type[pl.LightningDataModule]]
] = {
    "bigearthnet": (MultiLabelClassificationTask, BigEarthNetDataModule),
    "byol": (BYOLTask, ChesapeakeCVPRDataModule),
    "chesapeake_cvpr": (SemanticSegmentationTask, ChesapeakeCVPRDataModule),
    "cowc_counting": (RegressionTask, COWCCountingDataModule),
    "cyclone": (RegressionTask, CycloneDataModule),
    "eurosat": (ClassificationTask, EuroSATDataModule),
    "etci2021": (SemanticSegmentationTask, ETCI2021DataModule),
    "inria": (SemanticSegmentationTask, InriaAerialImageLabelingDataModule),
    "landcoverai": (SemanticSegmentationTask, LandCoverAIDataModule),
    "naipchesapeake": (SemanticSegmentationTask, NAIPChesapeakeDataModule),
    "oscd": (SemanticSegmentationTask, OSCDDataModule),
    "resisc45": (ClassificationTask, RESISC45DataModule),
    "sen12ms": (SemanticSegmentationTask, SEN12MSDataModule),
    "so2sat": (ClassificationTask, So2SatDataModule),
    "ucmerced": (ClassificationTask, UCMercedDataModule),
}


def write_mask(mask: torch.Tensor, output_dir: str, input_filename: str) -> None:
    """Write mask to specified output directory."""
    output_path = os.path.join(output_dir, os.path.basename(input_filename))
    with rio.open(input_filename) as src:
        profile = src.profile
    profile["count"] = 1
    profile["dtype"] = "uint8"
    # Cast to uint8 to match the output profile; rasterio rejects mismatched dtypes.
    mask_arr = mask.cpu().numpy().astype("uint8")
    with rio.open(output_path, "w", **profile) as ds:
        ds.write(mask_arr)


def main(config_dir: str, predict_on: str, output_dir: str, device: str) -> None:
    """Main inference loop."""
    os.makedirs(output_dir, exist_ok=True)

    # Load checkpoint and config
    conf = OmegaConf.load(os.path.join(config_dir, "experiment_config.yaml"))
    ckpt = os.path.join(config_dir, "last.ckpt")

    # Load model
    task_name = conf.experiment.task
    datamodule: pl.LightningDataModule
    task: pl.LightningModule
    if task_name not in TASK_TO_MODULES_MAPPING:
        raise ValueError(
            f"experiment.task={task_name} is not recognized as a valid task"
        )
    task_class, datamodule_class = TASK_TO_MODULES_MAPPING[task_name]
    task = task_class.load_from_checkpoint(ckpt)
    task = task.to(device)
    task.eval()

    # Load datamodule and dataloader
    conf.experiment.datamodule["predict_on"] = predict_on
    datamodule = datamodule_class(**conf.experiment.datamodule)
    datamodule.setup()
    dataloader = datamodule.predict_dataloader()

    if len(os.listdir(output_dir)) > 0:
        if conf.program.overwrite:
            print(
                f"WARNING! The output directory, {output_dir}, already exists, "
                + "we will overwrite data in it!"
            )
        else:
            raise FileExistsError(
                f"The predictions directory, {output_dir}, already exists and isn't "
                + "empty. We don't want to overwrite any existing results, exiting..."
            )

    for i, batch in enumerate(dataloader):
        x = batch["image"].to(device)  # (N, B, C, H, W)
        assert len(x.shape) in {4, 5}
        if len(x.shape) == 5:
            masks = []

            def tensor_to_int(
                tensor_tuple: Tuple[torch.Tensor, ...]
            ) -> Tuple[int, ...]:
                """Convert tuple of tensors to tuple of ints."""
                return tuple(int(i.item()) for i in tensor_tuple)

            original_shape = cast(
                Tuple[int, int], tensor_to_int(batch["original_shape"])
            )
            patch_shape = cast(Tuple[int, int], tensor_to_int(batch["patch_shape"]))
            padding = cast(Tuple[int, int], tensor_to_int(batch["padding"]))
            patch_combine = CombineTensorPatches(
                original_size=original_shape, window_size=patch_shape, unpadding=padding
            )

            for tile in x:
                mask = task(tile)
                mask = mask.argmax(dim=1)
                masks.append(mask)

            masks_arr = torch.stack(masks, dim=0)
            masks_arr = masks_arr.unsqueeze(0)
            masks_combined = patch_combine(masks_arr)[0]
            filename = datamodule.predict_dataset.files[i]["image"]
            write_mask(masks_combined, output_dir, filename)
        else:
            mask = task(x)
            mask = mask.argmax(dim=1)
            filename = datamodule.predict_dataset.files[i]["image"]
            write_mask(mask, output_dir, filename)


if __name__ == "__main__":
    # Taken from https://github.com/pangeo-data/cog-best-practices
    _rasterio_best_practices = {
        "GDAL_DISABLE_READDIR_ON_OPEN": "EMPTY_DIR",
        "AWS_NO_SIGN_REQUEST": "YES",
        "GDAL_MAX_RAW_BLOCK_CACHE_SIZE": "200000000",
        "GDAL_SWATH_SIZE": "200000000",
        "VSI_CURL_CACHE_SIZE": "200000000",
    }
    os.environ.update(_rasterio_best_practices)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config-dir",
        type=str,
        required=True,
        help="Path to config-dir to load config and ckpt",
    )
    parser.add_argument(
        "--predict_on",
        type=str,
        required=True,
        help="Directory/Dataset to run inference on",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Path to output_directory to save predicted mask geotiffs",
    )
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"])
    args = parser.parse_args()
    main(args.config_dir, args.predict_on, args.output_dir, args.device)
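
For reference, a minimal sketch of driving this script end to end; the paths below are hypothetical placeholders, but the arguments mirror the argparse flags and the main() signature above. The config directory is assumed to contain the experiment_config.yaml and last.ckpt written out by a training run:

# Hypothetical invocation; equivalent to:
#   python predict.py --config-dir runs/inria --predict_on data/tiles --output-dir preds
from predict import main

main(
    config_dir="runs/inria",  # must contain experiment_config.yaml and last.ckpt
    predict_on="data/tiles",  # a directory of GeoTIFFs, or a split name like "test"
    output_dir="preds",       # predicted mask GeoTIFFs are written here
    device="cuda",
)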
16 changes: 11 additions & 5 deletions torchgeo/datamodules/inria.py
@@ -3,6 +3,7 @@
 
 """InriaAerialImageLabeling datamodule."""
 
+import os
 from typing import Any, Dict, List, Optional, Tuple, Union, cast
 
 import kornia.augmentation as K
@@ -14,7 +15,7 @@
 from torch.utils.data import DataLoader, Dataset
 from torch.utils.data._utils.collate import default_collate
 
-from ..datasets import InriaAerialImageLabeling
+from ..datasets import InriaAerialImageLabeling, PredictDataset
 from ..samplers.utils import _to_tuple
 from .utils import dataset_split
 
@@ -164,10 +165,15 @@ def setup(self, stage: Optional[str] = None) -> None:
             self.val_dataset = train_dataset
             self.test_dataset = train_dataset
 
-        assert self.predict_on == "test"
-        self.predict_dataset = InriaAerialImageLabeling(
-            self.root_dir, self.predict_on, transforms=test_transforms
-        )
+        if os.path.isdir(self.predict_on):
+            self.predict_dataset = PredictDataset(
+                self.predict_on, patch_size=self.patch_size, transforms=self.preprocess
+            )
+        else:
+            assert self.predict_on == "test"
+            self.predict_dataset = InriaAerialImageLabeling(  # type: ignore[assignment]
+                self.root_dir, self.predict_on, transforms=test_transforms
+            )
 
     def train_dataloader(self) -> DataLoader[Any]:
         """Return a DataLoader for training."""
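
The effect of this change, as a short sketch: predict_on now accepts either the name of a built-in split or a path to a directory of images. This assumes the datamodule forwards predict_on as a constructor keyword (predict.py injects it via conf.experiment.datamodule); the paths are hypothetical:

from torchgeo.datamodules import InriaAerialImageLabelingDataModule

# Previous behaviour: predict on the dataset's own "test" split.
dm = InriaAerialImageLabelingDataModule(root_dir="data/inria", predict_on="test")

# New behaviour: predict on an arbitrary directory of GeoTIFFs via PredictDataset.
dm = InriaAerialImageLabelingDataModule(root_dir="data/inria", predict_on="my_tiles")

dm.setup()
loader = dm.predict_dataloader()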
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
@@ -92,6 +92,7 @@
 from .usavars import USAVars
 from .utils import (
     BoundingBox,
+    PredictDataset,
     concat_samples,
     merge_samples,
     stack_samples,
@@ -192,6 +193,7 @@
     "VisionClassificationDataset",
     # Utilities
     "BoundingBox",
+    "PredictDataset",
     "concat_samples",
     "merge_samples",
     "stack_samples",
104 changes: 104 additions & 0 deletions torchgeo/datasets/utils.py
@@ -15,6 +15,7 @@
 from datetime import datetime, timedelta
 from typing import (
     Any,
+    Callable,
     Dict,
     Iterable,
     Iterator,
@@ -30,7 +31,11 @@
 import numpy as np
 import rasterio
 import torch
+import torchvision.transforms as T
+from einops import rearrange
+from kornia.contrib import compute_padding, extract_tensor_patches
 from torch import Tensor
+from torch.utils.data import Dataset
 from torchvision.datasets.utils import check_integrity, download_url
 from torchvision.utils import draw_segmentation_masks
 
@@ -51,9 +56,108 @@
     "draw_semantic_segmentation_masks",
     "rgb_to_mask",
     "percentile_normalization",
+    "PredictDataset",
 )
 
 
+class PredictDataset(Dataset[Any]):
+    """Prediction dataset for VisionDatasets."""
+
+    def __init__(
+        self,
+        root: str,
+        patch_size: Tuple[int, int],
+        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
+        bands: Tuple[int, ...] = (1, 2, 3),
+    ) -> None:
+        """Initialize a new PredictDataset instance.
+
+        Args:
+            root: root directory where dataset can be found
+            patch_size: size of patch used as input for the model
+            transforms: a function/transform that takes an input sample
+                and returns a transformed version
+            bands: bands to be used
+        """
+        self.root = root
+        self.patch_size = patch_size
+        self.bands = bands
+        # patch_sample is always applied last; user transforms must not include it
+        if transforms:
+            self.transforms = T.Compose([transforms, self.patch_sample])
+        else:
+            self.transforms = self.patch_sample
+        self.files = self._load_files(root)
+
+    def patch_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+        """Extract patches from a single sample."""
+        assert sample["image"].ndim == 3
+        _, h, w = sample["image"].shape
+
+        padding = compute_padding((h, w), self.patch_size)
+        sample["original_shape"] = (h, w)
+        sample["patch_shape"] = self.patch_size
+        sample["padding"] = padding
+        sample["image"] = extract_tensor_patches(
+            sample["image"].unsqueeze(0),
+            self.patch_size,
+            self.patch_size,
+            padding=padding,
+        )
+        sample["image"] = rearrange(sample["image"], "() t c h w -> t () c h w")
+        return sample
+
+    def _load_files(self, root: str) -> List[Dict[str, str]]:
+        """Return the paths of the files in the dataset.
+
+        Args:
+            root: root dir of dataset
+
+        Returns:
+            list of dicts containing the path for each image
+        """
+        images = [os.path.join(root, i) for i in os.listdir(root)]
+        return [{"image": img} for img in images]
+
+    def __len__(self) -> int:
+        """Return the number of samples in the dataset.
+
+        Returns:
+            length of the dataset
+        """
+        return len(self.files)
+
+    def _load_image(self, path: str) -> Tensor:
+        """Load a single image.
+
+        Args:
+            path: path to the image
+
+        Returns:
+            the image
+        """
+        with rasterio.open(path) as img:
+            array = img.read(self.bands).astype(np.int32)
+            tensor: Tensor = torch.from_numpy(array)
+            return tensor
+
+    def __getitem__(self, index: int) -> Dict[str, Tensor]:
+        """Return an index within the dataset.
+
+        Args:
+            index: index to return
+
+        Returns:
+            data at that index
+        """
+        file = self.files[index]
+        img = self._load_image(file["image"])
+        sample = {"image": img}
+        sample = self.transforms(sample)
+        return sample
+
+
 class _rarfile:
     class RarFile:
         def __init__(self, *args: Any, **kwargs: Any) -> None:
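
A small usage sketch for the new class, assuming a directory of 3-band GeoTIFFs at a hypothetical path:

from torchgeo.datasets import PredictDataset

ds = PredictDataset("data/tiles", patch_size=(512, 512))
sample = ds[0]
# patch_sample() converts one (C, H, W) image into patches shaped
# (num_patches, 1, C, 512, 512) and records the metadata that
# predict.py later feeds to CombineTensorPatches for stitching.
print(sample["image"].shape)
print(sample["original_shape"], sample["patch_shape"], sample["padding"])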
