Stitching: Access patch geo-transform in callback after predict #1407
This is definitely doable, and I think you're on the right track. I'll just add a couple of things to the pot and let you keep stirring. One thing that should be clarified is whether you are trying to stitch together patches/chips taken from a single large tile/scene, or if you also want to stitch together entire scenes into a larger extent (e.g., a single map for all of North America). @calebrob6 is in favor of leaving the latter to GDAL and only focusing on the former within TorchGeo, which I think I agree with, especially for memory constraint reasons.
At the moment, all GeoDatasets store R-tree bounds and load images/masks in a common CRS. This may very well change in the future: #409. I think this part is actually relatively easy to implement and should help with #409 as well. Currently we only store the CRS. We should also store the transform and window, like you said.

The tricky part is figuring out how to pass this metadata through the datamodule to the trainer. As you noticed, we currently prevent that in `GeoDataModule.transfer_batch_to_device`. The transform and window are relatively easy since they are just matrices: we could convert back and forth between rasterio objects and PyTorch tensors. The CRS is different. I would need to look into the internals of Lightning, but I wonder if there is some way to keep some keys in each batch on the CPU instead of transferring them. My first thought would be to modify our override so that it only transfers some keys and skips others. If that doesn't work, maybe open a discussion on the Lightning repo and ask for advice on how to do this.

One other thought while we're on the topic. Ideally, we would be able to transfer all metadata from the file to the sample, including things like driver, dtype, nodata, compression, filename, etc. This gets even more complicated with RasterDatasets where each band is stored in a separate file. The idea is to make a prediction that replicates the input data as closely as possible (same driver, dtype, etc.). We should also think about where we want to save predictions (probably not the same location as the source data). Excited to see progress in this direction!
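The "only transfer some keys" idea can be sketched independently of Lightning. In this minimal sketch, `transfer_selected` and the default `keep_on_cpu` key are hypothetical; `move` stands in for whatever performs the real device transfer (inside a `transfer_batch_to_device` override it could wrap Lightning's `move_data_to_device`).

```python
from collections.abc import Callable, Iterable
from typing import Any


def transfer_selected(
    batch: dict[str, Any],
    move: Callable[[Any], Any],
    keep_on_cpu: Iterable[str] = ("crs",),
) -> dict[str, Any]:
    """Apply `move` to every value except the keys listed in `keep_on_cpu`.

    Values under `keep_on_cpu` (e.g. a rasterio CRS object) are passed through
    untouched, so they stay on the host while tensors go to the device.
    """
    skip = set(keep_on_cpu)
    return {key: value if key in skip else move(value) for key, value in batch.items()}


# Example with a stand-in "device move" that just tags each value.
batch = {"image": [1.0, 2.0], "crs": "EPSG:32633"}
moved = transfer_selected(batch, move=lambda v: ("device", v))
```

Inside a `GeoDataModule` subclass the override might then return `transfer_selected(batch, lambda v: move_data_to_device(v, device))`; whether the untouched keys cause trouble elsewhere in Lightning has not been checked here.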
I have some ideas. We could pass everything we need along with the batch by converting the CRS and window transform (affine) to arrays, then back to an affine matrix afterwards, as you say. For the CRS, we would store the EPSG code as a scalar. Alternatively, we could store arrays for these in the dataset and do a lookup using an index that we pass with the batch instead. I will start working on this when I get my computer back from a fire in our offices 😬
Not all CRS map to an EPSG code. Rasterio will attempt to find one that's close enough, but it won't be exact if we do this. |
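One way around the lossy EPSG round trip is the lookup-table idea from the previous comment: keep the full CRS (for example as WKT from rasterio's `CRS.to_wkt()`) together with the six transform coefficients in a table on the dataset side, and pass only an integer index through the batch. A minimal stdlib sketch; `MetadataTable` and its methods are hypothetical names, not TorchGeo API.

```python
from dataclasses import dataclass, field


@dataclass
class MetadataTable:
    """Maps (CRS WKT, affine coefficients) records to small integer ids.

    Only the integer id has to travel with the batch (it fits in a tensor);
    the exact CRS string never has to be squeezed into an EPSG code.
    """

    records: list[tuple[str, tuple[float, ...]]] = field(default_factory=list)
    ids: dict[tuple[str, tuple[float, ...]], int] = field(default_factory=dict)

    def register(self, crs_wkt: str, transform: tuple[float, ...]) -> int:
        record = (crs_wkt, tuple(transform))
        if record not in self.ids:
            self.ids[record] = len(self.records)
            self.records.append(record)
        return self.ids[record]

    def lookup(self, record_id: int) -> tuple[str, tuple[float, ...]]:
        return self.records[record_id]


table = MetadataTable()
wkt = "<WKT string from CRS.to_wkt()>"  # placeholder, not a real CRS definition
idx = table.register(wkt, (10.0, 0.0, 500000.0, 0.0, -10.0, 4000000.0))
```

Registering the same file metadata twice returns the same id, so each sample only needs to carry that id.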
Had another look at this and wrote down my thoughts. Keeping it simple, I have a dataset consisting of one raster file product (only one CRS). GridGeoSampler yields patches (windows) that are fed in batches to an object detector.
Are my assumptions correct? I would think that many real-life scenarios involve only one product during prediction.

```python
import os
from typing import Any

import rasterio
import rasterio.transform
import rasterio.windows
import torch
from lightning.fabric.utilities.apply_func import move_data_to_device
from lightning.pytorch import Callback, LightningModule, Trainer
from torch import Tensor
from torchgeo.datamodules import GeoDataModule


def do_non_max_suppression(points: list[Tensor]) -> Tensor:
    # Hypothetical placeholder: merge duplicate detections from overlapping
    # patches. As written it only concatenates; plug in a real NMS here.
    return torch.cat(points)


class CustomDataModule(GeoDataModule):
    def transfer_batch_to_device(
        self, batch: dict[str, Any], device: torch.device, dataloader_idx: int
    ) -> dict[str, Any]:
        # Don't need the CRS unless we reproject to a new CRS (it isn't a tensor)
        batch.pop("crs", None)
        # bbox is the extent of each sample/patch which, when using a single
        # image file, can be used to create a rasterio window transform.
        # Rows are (minx, maxx, miny, maxy, mint, maxt); float64 keeps
        # projected coordinates precise.
        batch["bbox"] = torch.as_tensor(
            [list(bbox) for bbox in batch["bbox"]], dtype=torch.float64
        )
        # Bypass GeoDataModule.transfer_batch_to_device, which removes bbox
        return move_data_to_device(batch, device)


class PatchStitchingCallback(Callback):
    def __init__(self, output_dir: str) -> None:
        super().__init__()
        self.output_dir = output_dir
        # Key points in product pixel coordinates, collected over all batches
        self._product_pixels: list[Tensor] = []

    def on_predict_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        # This only works when there is a single file in the dataset
        raster_filepath = trainer.datamodule.predict_sampler.hits[0].object
        with rasterio.open(raster_filepath) as src:
            product_transform = src.transform

        # Assumes predict_step returns one dict per sample in the batch
        for bbox, output in zip(batch["bbox"], outputs):
            # bbox is the window/patch extent; from_bounds wants (left, bottom, right, top)
            minx, maxx, miny, maxy = bbox[:4].tolist()
            window = rasterio.windows.from_bounds(
                minx, miny, maxx, maxy, transform=product_transform
            )
            window_transform = rasterio.windows.transform(window, product_transform)
            # Given that our detection model outputs pixel coordinates
            # relative to the patch:
            # tensor([[row1, col1],
            #         [row2, col2]])
            rows, cols = output["key_points"].cpu().numpy().T
            xs, ys = rasterio.transform.xy(window_transform, rows, cols)
            product_rows, product_cols = rasterio.transform.rowcol(
                product_transform, xs, ys
            )
            output["key_points_product_crs"] = torch.stack(
                (torch.as_tensor(xs), torch.as_tensor(ys)), dim=1
            )
            product_pixels = torch.stack(
                (torch.as_tensor(product_rows), torch.as_tensor(product_cols)), dim=1
            )
            output["key_points_product_pixels"] = product_pixels
            # Keep our own reference: Lightning may already have stored a
            # CPU copy of outputs, so in-place edits may not reach it
            self._product_pixels.append(product_pixels)

    def write_on_epoch_end(self, predictions: Tensor) -> None:
        os.makedirs(self.output_dir, exist_ok=True)
        torch.save(predictions, os.path.join(self.output_dir, "predictions.pt"))

    def on_predict_epoch_end(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> None:
        # Merge all predictions from overlapping patches
        predictions = do_non_max_suppression(self._product_pixels)
        self.write_on_epoch_end(predictions)
        self._product_pixels.clear()


# Usage:
model = ...  # your LightningModule
datamodule = CustomDataModule(...)  # usual GeoDataModule arguments
trainer = Trainer(callbacks=[PatchStitchingCallback(output_dir="my_dir")])
trainer.predict(model=model, datamodule=datamodule)
```

Took some inspiration from PyTorch Lightning's `BasePredictionWriter` callback.
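The merge step above is left to the hypothetical `do_non_max_suppression` helper. For key points, one simple option is greedy point suppression: sort detections by score and drop any point that lies within a radius of an already kept one. This sketch assumes each detection is `(row, col, score)` in product pixel coordinates and that a fixed pixel radius suits the objects being detected; both are assumptions.

```python
import math


def suppress_points(
    detections: list[tuple[float, float, float]], radius: float
) -> list[tuple[float, float, float]]:
    """Greedy non-max suppression for point detections.

    detections: (row, col, score) tuples in a shared (product) pixel grid,
    e.g. key points collected from overlapping patches.
    A detection is kept only if no higher-scoring kept point lies within
    `radius` pixels of it.
    """
    kept: list[tuple[float, float, float]] = []
    for row, col, score in sorted(detections, key=lambda d: d[2], reverse=True):
        if all(math.hypot(row - r, col - c) > radius for r, c, _ in kept):
            kept.append((row, col, score))
    return kept


# The same object seen in two overlapping patches, plus one distinct object.
points = [(100.0, 200.0, 0.9), (101.0, 199.0, 0.8), (400.0, 50.0, 0.7)]
merged = suppress_points(points, radius=3.0)
```

The pairwise check is O(n²); for many points, a spatial index such as a KD-tree would be the usual replacement.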
For a more general implementation, we'll need to support multiple files.
Correct.
I assume this satisfies #35?
Each file in the RasterDataset gets its own entry in the R-tree. We can store arbitrary objects (or tuples of objects) in the R-tree. We currently only store filename, bbox, and CRS, but we'll likely need to add transform and possibly more metadata. The hard part will be figuring out how to get this to work in Lightning where the module wants to transfer everything to the GPU.
I think this will be a problem. For the average multi-file dataset, we'll quickly exceed memory if we wait until the end of each epoch to save predictions.
I leave it up to the model to output coordinates relative to its input. If additional transforms are applied after the sampler query (before reaching predict_step), then we need to reverse them. It would be best to apply these transforms in the sampler so that the query is correct. This way we can trust that the bbox/query can be converted to a representative window transform.
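I found that we don't need the product (file) transform in order to convert to the window transform. This:

```python
window = rasterio.windows.from_bounds(minx, miny, maxx, maxy, transform=product_transform)
window_transform = rasterio.windows.transform(window, product_transform)
```

is equivalent to:

```python
patch_size = trainer.datamodule.patch_size
window_transform = rasterio.transform.from_bounds(
    minx, miny, maxx, maxy, patch_size, patch_size
)
```

where `minx, maxx, miny, maxy = bbox[:4]`, assuming the patch is sampled at the file's native resolution. NB! The ordering of the bbox/sampler query (minx, maxx, miny, maxy) is not the same as what `from_bounds` needs (west/left, south/bottom, east/right, north/top), so it has to be unpacked and reordered as above.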
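The equivalence holds because `rasterio.transform.from_bounds` only needs the patch extent and its size in pixels. A pure-Python sketch of the same arithmetic for north-up rasters without rotation, including the reordering from the TorchGeo-style bbox; the function names are illustrative, not rasterio API, and the patch is assumed square and aligned to the pixel grid.

```python
from collections.abc import Sequence


def transform_from_bounds(
    west: float, south: float, east: float, north: float, width: int, height: int
) -> tuple[float, float, float, float, float, float]:
    """Affine coefficients (a, b, c, d, e, f) with x = a*col + b*row + c and
    y = d*col + e*row + f, in the same order rasterio uses."""
    return ((east - west) / width, 0.0, west, 0.0, (south - north) / height, north)


def bbox_to_window_transform(bbox: Sequence[float], patch_size: int):
    # Bbox order is (minx, maxx, miny, maxy, mint, maxt);
    # from_bounds expects (west, south, east, north).
    minx, maxx, miny, maxy = bbox[:4]
    return transform_from_bounds(minx, miny, maxx, maxy, patch_size, patch_size)


def pixel_center(transform, row: int, col: int) -> tuple[float, float]:
    # Map coordinates of a pixel centre (like rasterio.transform.xy's default offset)
    a, b, c, d, e, f = transform
    return (a * (col + 0.5) + b * (row + 0.5) + c, d * (col + 0.5) + e * (row + 0.5) + f)


bbox = (500000.0, 500256.0, 4000000.0, 4000256.0, 0.0, 1.0)
window_transform = bbox_to_window_transform(bbox, patch_size=256)
```

Here a 256 m patch of 256 pixels yields 1 m pixels, with the origin at the patch's top-left corner (500000, 4000256).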
It seems like Lightning keeps track of all predictions anyway, so memory may already be a problem in that case. But I found out that, on multi-GPU/node,
This is a good point, and would probably be sufficient for a first pass to solve any concerns about memory limitations. I would still love to ensure that GridGeoSampler distributes each file to the same device and flushes memory once a full file is complete, but this sounds like a lot more bookkeeping.
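The per-file flushing idea can be sketched as a small buffer that knows how many patches each file contributes (for a grid sampler this count is known up front) and writes a file's predictions as soon as its last patch arrives. `PerFileBuffer`, the expected counts, and the `write` callback are hypothetical; routing each file to a single device is assumed, not shown.

```python
from collections import defaultdict
from collections.abc import Callable
from typing import Any


class PerFileBuffer:
    """Collect per-patch predictions and flush each file once it is complete."""

    def __init__(self, expected: dict[str, int], write: Callable[[str, list[Any]], None]):
        self.expected = expected  # number of patches the sampler yields per file
        self.write = write  # e.g. stitch and save one output raster per source file
        self.pending: dict[str, list[Any]] = defaultdict(list)

    def add(self, filepath: str, prediction: Any) -> None:
        self.pending[filepath].append(prediction)
        if len(self.pending[filepath]) == self.expected[filepath]:
            # Hand over the complete set of patches and release the memory.
            self.write(filepath, self.pending.pop(filepath))


written: dict[str, int] = {}
buffer = PerFileBuffer(
    {"a.tif": 2, "b.tif": 1}, lambda path, preds: written.update({path: len(preds)})
)
buffer.add("a.tif", "patch-0")
buffer.add("b.tif", "patch-0")
buffer.add("a.tif", "patch-1")
```

Peak memory is then bounded by the files that are partially processed at any moment rather than by the whole epoch.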
Yes, this would indeed be good. What do you think about the output reference?
My guess is that most use cases predict on only one file at a time, so I'm leaning towards option 1 as the first step. For data fusion use cases this may not be the case. It also depends on how the end results will be used. Some users may want vector points in WGS84, which would be easy to implement. But for segmentation masks, choosing a common CRS may lead to incorrect masks? Some might want to run predict on, e.g., a whole Sentinel-2 datatake. These can be up to 15k km long, so I'm struggling to decide what the CRS/reference frame should be for multi-file prediction.
I would support both GeoDataset and NonGeoDataset. The former can use the same CRS as the dataset (then we won't need access to the original file), while the latter can simply save as a PNG.
For NonGeoDatasets, is there an equivalent to GridGeoSampler? I presume these images are also processed in patches and need to be stitched. In that case, we would need the stride/overlap used in order to stitch them. I think the proposed callback above would save the predictions and corresponding sample indices, and another process would collect and stitch them, similarly to GeoDatasets.
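If patches do overlap, stitching needs each patch's offset (which follows from the stride) and a rule for the overlapping pixels. A common choice is to average them; the stdlib sketch below does that for single-band score maps stored as nested lists. The function name and the averaging rule are assumptions, not an existing TorchGeo utility.

```python
def stitch_mean(
    patches: list[list[list[float]]],
    offsets: list[tuple[int, int]],
    height: int,
    width: int,
) -> list[list[float]]:
    """Average overlapping patch predictions into one height x width grid.

    offsets[i] is the (row, col) of patch i's top-left pixel in the full image,
    e.g. (i_row * stride, i_col * stride) for a regular grid.
    """
    total = [[0.0] * width for _ in range(height)]
    count = [[0] * width for _ in range(height)]
    for patch, (row0, col0) in zip(patches, offsets):
        for r, patch_row in enumerate(patch):
            for c, value in enumerate(patch_row):
                total[row0 + r][col0 + c] += value
                count[row0 + r][col0 + c] += 1
    # Pixels that no patch covered stay 0.0.
    return [
        [t / n if n else 0.0 for t, n in zip(total_row, count_row)]
        for total_row, count_row in zip(total, count)
    ]


# Two 2x2 patches with a column stride of 1 on a 2x3 image (one column overlaps).
left = [[1.0, 1.0], [1.0, 1.0]]
right = [[3.0, 3.0], [3.0, 3.0]]
stitched = stitch_mean([left, right], [(0, 0), (0, 1)], height=2, width=3)
```

For class masks rather than scores, averaging per-class probabilities and taking the argmax afterwards is one way to apply the same idea.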
The majority of NonGeoDatasets already consist of patches and do not require stitching. For tile-based NonGeoDatasets, we have a
Originally posted by @adamjstewart in #560 (comment)
Relevant issues and PRs: #30, #35, #411
To stitch together predictions made on patches from the dataset, we need to transform them to a common grid/reference system.
Let's say my `predict_step`'s output is relative to the resolution of the sample patch (e.g. a segmentation mask, or boxes/points relative to the patch resolution and size). Then the transformation to the original (full) image could be performed using `pyproj.Transformer` for coordinates, or `rasterio.windows.transform`. pyproj would also let us convert to any CRS. But then we need a reference to this sample's window and source image file.
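For the single-file case, mapping patch-relative pixel coordinates back to the full image only needs the patch window's pixel offset, and the product transform then gives map coordinates. This mirrors the composition `rasterio.windows.transform` performs (the window transform's origin is the product transform applied to the window offset). A stdlib sketch with affine coefficients in rasterio's `(a, b, c, d, e, f)` order; the names and numbers are illustrative.

```python
def window_transform(product, col_off: float, row_off: float):
    """Shift the product transform's origin to the window's top-left pixel."""
    a, b, c, d, e, f = product
    return (a, b, a * col_off + b * row_off + c, d, e, d * col_off + e * row_off + f)


def patch_to_product_pixel(row: float, col: float, row_off: float, col_off: float):
    # Same grid and resolution: only the window offset differs.
    return row + row_off, col + col_off


def pixel_to_map(transform, row: float, col: float):
    # Map coordinates of the pixel's upper-left corner.
    a, b, c, d, e, f = transform
    return a * col + b * row + c, d * col + e * row + f


product = (10.0, 0.0, 300000.0, 0.0, -10.0, 5000000.0)  # 10 m pixels
win = window_transform(product, col_off=256, row_off=512)
```

Either route gives the same map coordinates: through the window transform directly, or by first shifting to product pixels and then applying the product transform.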
My initial thought on where to implement this was in Lightning Callback on_predict_batch_end. It has access to the trainer, and thus the dataset. Then, if we had a way to access the patch/window_transform and the source file, we could do this.
To my understanding, GeoDataModule.transfer_batch_to_device removes the reference to the patch/window due to it not being a tensor. For regular datasets, we can pass the index through. Would we need something similar here?
Thoughts? Are there any utilities for RasterDataset that I have yet to find that can help solve this?