GeoDataset with pytorch lightning datamodule not compatible with BoundingBox in sample #1056

Closed
nilsleh opened this issue Jan 27, 2023 · 6 comments
Labels
datamodules PyTorch Lightning datamodules

Comments

@nilsleh
Collaborator

nilsleh commented Jan 27, 2023

Description

I am using a custom RasterDataset with a LightningDataModule. However, PyTorch Lightning cannot handle a BoundingBox object returned as part of the sample, because it cannot move it to the device:

raise MisconfigurationException(
pytorch_lightning.utilities.exceptions.MisconfigurationException: A frozen dataclass was passed to `apply_to_collection` but this is not allowed. HINT: is your batch a frozen dataclass?

If I change the code to turn the BoundingBox into a list, everything works.
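A minimal sketch of that list workaround, applied to each sample before collation. It assumes each sample is a dict whose "bbox" entry is a frozen dataclass with the fields minx, maxx, miny, maxy, mint, maxt (assumed to match torchgeo 0.4.0's BoundingBox). To run without torchgeo it uses a hypothetical stand-in class `FakeBoundingBox`; `bbox_to_list` and `make_collate` are also hypothetical helper names, not torchgeo API.

```python
from dataclasses import astuple, dataclass, is_dataclass
from typing import Any, Callable, Dict, List

Sample = Dict[str, Any]


@dataclass(frozen=True)
class FakeBoundingBox:
    """Hypothetical stand-in for torchgeo's frozen BoundingBox dataclass."""

    minx: float
    maxx: float
    miny: float
    maxy: float
    mint: float
    maxt: float


def bbox_to_list(sample: Sample, key: str = "bbox") -> Sample:
    """Return a shallow copy of ``sample`` with a dataclass bbox turned into a list.

    The original sample is left untouched.
    """
    out = dict(sample)
    value = out.get(key)
    # is_dataclass() is also True for the class itself, so exclude types.
    if is_dataclass(value) and not isinstance(value, type):
        out[key] = list(astuple(value))
    return out


def make_collate(
    collate_fn: Callable[[List[Sample]], Any]
) -> Callable[[List[Sample]], Any]:
    """Wrap a collate function (e.g. stack_samples) so that every sample's
    bbox is converted to a list before collation."""

    def wrapped(samples: List[Sample]) -> Any:
        return collate_fn([bbox_to_list(s) for s in samples])

    return wrapped
```

In the repro below, `collate_fn=make_collate(stack_samples)` would then hand Lightning lists of floats instead of BoundingBox objects. The trade-off is that downstream code has to remember the field order.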

Steps to reproduce

import os
import tempfile

from torch.utils.data import DataLoader

from torchgeo.datasets import NAIP, ChesapeakeDE, stack_samples
from torchgeo.datasets.utils import download_url
from torchgeo.samplers import RandomGeoSampler
from torchgeo.trainers import SemanticSegmentationTask
import pytorch_lightning as pl

data_root = tempfile.gettempdir()
naip_root = os.path.join(data_root, "naip")
naip_url = (
    "https://naipeuwest.blob.core.windows.net/naip/v002/de/2018/de_060cm_2018/38075/"
)
tiles = [
    "m_3807511_ne_18_060_20181104.tif",
    "m_3807511_se_18_060_20181104.tif",
    "m_3807512_nw_18_060_20180815.tif",
    "m_3807512_sw_18_060_20180815.tif",
]
for tile in tiles:
    download_url(naip_url + tile, naip_root)


class MyDataModule(pl.LightningDataModule):
    def __init__(self, data_root, naip_root):
        super().__init__()
        self.naip = NAIP(naip_root)
        chesapeake_root = os.path.join(data_root, "chesapeake")
        self.chesapeake = ChesapeakeDE(chesapeake_root, crs=self.naip.crs, res=self.naip.res, download=True)
        
    def train_dataloader(self):
        dataset = self.naip & self.chesapeake
        sampler = RandomGeoSampler(dataset, size=1000, length=10)
        return DataLoader(dataset, sampler=sampler, collate_fn=stack_samples)

    def val_dataloader(self):
        dataset = self.naip & self.chesapeake
        sampler = RandomGeoSampler(dataset, size=1000, length=10)
        return DataLoader(dataset, sampler=sampler, collate_fn=stack_samples)

    def test_dataloader(self):
        dataset = self.naip & self.chesapeake
        sampler = RandomGeoSampler(dataset, size=1000, length=10)
        return DataLoader(dataset, sampler=sampler, collate_fn=stack_samples)

    def on_after_batch_transfer(self, batch, dataloader_idx: int):
        # Some transforms here. Lightning runs this hook after
        # transfer_batch_to_device, which is where batch["bbox"] fails.
        return batch

task = SemanticSegmentationTask(
    model="unet",
    backbone="resnet18",
    weights=None,
    in_channels=4,
    num_classes=13,
    loss="jaccard",
    learning_rate=0.01,
    ignore_index=0,
    learning_rate_schedule_patience=1
)

datamodule = MyDataModule(data_root, naip_root)

trainer = pl.Trainer(fast_dev_run=True)
trainer.fit(task, datamodule=datamodule)

Version

0.4.0

@nilsleh nilsleh changed the title GeoDataset with pytorch lightning datamodule GeoDataset with pytorch lightning datamodule not compatible with BoundingBox in sample Jan 27, 2023
@adamjstewart
Collaborator

The solution is to subclass GeoDataModule, which contains the following code to handle this: https://github.com/microsoft/torchgeo/blob/v0.4.0/torchgeo/datamodules/geo.py#L280,L282
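The linked lines are not reproduced here. As a hedged sketch of the same idea, assuming the batch is a dict whose "crs" and "bbox" entries cannot be moved to a device, the datamodule drops those keys before Lightning's device transfer runs. `drop_unmovable_keys` is a hypothetical helper written without Lightning so it runs standalone; the commented hook shows where it could be called from a LightningDataModule subclass.

```python
from typing import Any, Dict, Iterable

# Keys assumed to hold values Lightning cannot move to a device.
UNMOVABLE_KEYS = ("crs", "bbox")


def drop_unmovable_keys(
    batch: Dict[str, Any], keys: Iterable[str] = UNMOVABLE_KEYS
) -> Dict[str, Any]:
    """Return a copy of ``batch`` without ``keys``; missing keys are ignored."""
    skip = set(keys)
    return {k: v for k, v in batch.items() if k not in skip}


# Possible use inside a LightningDataModule subclass (hook signature assumed
# from PyTorch Lightning 1.x):
#
#     def transfer_batch_to_device(self, batch, device, dataloader_idx):
#         batch = drop_unmovable_keys(batch)
#         return super().transfer_batch_to_device(batch, device, dataloader_idx)
```

Returning a new dict rather than deleting in place leaves the caller's batch intact, at the cost of a shallow copy per batch.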

@adamjstewart adamjstewart added the datamodules PyTorch Lightning datamodules label Jan 27, 2023
@nilsleh
Collaborator Author

nilsleh commented Jan 28, 2023

Oh, sorry, I wasn't aware of that. Looking at it, should we represent bbox and crs in a compatible datatype instead of deleting them? I think that is still valuable information that people might want to use in their pipeline, for visualizations etc.
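One hedged way to keep bbox and crs instead of deleting them: hand only the device-movable entries to Lightning's device transfer and re-attach the metadata afterwards, so later hooks and steps still see it. A minimal sketch, assuming a dict batch; `split_metadata` and `merge_metadata` are hypothetical helpers, and treating exactly "crs" and "bbox" as metadata is an assumption.

```python
from typing import Any, Dict, Iterable, Tuple

Batch = Dict[str, Any]


def split_metadata(
    batch: Batch, keys: Iterable[str] = ("crs", "bbox")
) -> Tuple[Batch, Batch]:
    """Split ``batch`` into (device-movable entries, metadata entries)."""
    meta_keys = set(keys)
    movable = {k: v for k, v in batch.items() if k not in meta_keys}
    metadata = {k: v for k, v in batch.items() if k in meta_keys}
    return movable, metadata


def merge_metadata(moved: Batch, metadata: Batch) -> Batch:
    """Re-attach metadata to a batch whose tensors were already moved."""
    merged = dict(moved)
    merged.update(metadata)
    return merged


# Possible use inside a LightningDataModule subclass (hook signature assumed
# from PyTorch Lightning 1.x):
#
#     def transfer_batch_to_device(self, batch, device, dataloader_idx):
#         movable, metadata = split_metadata(batch)
#         moved = super().transfer_batch_to_device(movable, device, dataloader_idx)
#         return merge_metadata(moved, metadata)
```

The metadata stays on the CPU as plain Python objects, which is usually what visualization or patch-stitching code needs anyway.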

@adamjstewart
Collaborator

CRS might be tricky. Bbox might be possible, but I would rather wait until someone has a use case for it (probably stitching together prediction patches). It might be as simple as implementing __len__.

@adriantre
Contributor

I have a use case for it: #1407

Our models (object detection) output pixel coordinates relative to the input patch. I then need to convert them into full-image pixel coordinates for non-max suppression, and finally convert them to WGS-84 for delivery to end users.
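The first conversion step can be sketched as plain arithmetic, assuming north-up rasters, a patch and full image that share the same square pixel size `res` (in CRS units), and pixel coordinates measured as (col, row) from the top-left corner. The patch's minx/maxy would come from its bbox and the image's from the full raster's bounds. The function names are hypothetical.

```python
from typing import Tuple

Box = Tuple[float, float, float, float]  # (col_min, row_min, col_max, row_max)


def patch_offset(
    patch_minx: float,
    patch_maxy: float,
    image_minx: float,
    image_maxy: float,
    res: float,
) -> Tuple[float, float]:
    """Pixel offset (cols, rows) of the patch's top-left corner in the image."""
    col_off = (patch_minx - image_minx) / res
    # Rows grow downwards while y grows upwards, hence image_maxy - patch_maxy.
    row_off = (image_maxy - patch_maxy) / res
    return col_off, row_off


def box_to_image_pixels(box: Box, offset: Tuple[float, float]) -> Box:
    """Shift a patch-relative detection box into full-image pixel coordinates."""
    col_off, row_off = offset
    c0, r0, c1, r1 = box
    return c0 + col_off, r0 + row_off, c1 + col_off, r1 + row_off
```

Once all boxes share one pixel grid, non-max suppression can be run across patch borders.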

The same workflow would likely apply to segmentation.

So I need a way to access the patch-transforms after prediction.
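A "patch transform" here can be read as the affine map from patch pixel coordinates to CRS coordinates. A minimal sketch, assuming a north-up patch with square pixels of size `res`, using the (a, b, c, d, e, f) coefficient order of the `affine` package that rasterio uses; `PatchTransform`, `transform_from_bbox` and `pixel_to_map` are hypothetical names. With rasterio installed, `rasterio.transform.from_origin(minx, maxy, res, res)` should build the equivalent transform.

```python
from typing import NamedTuple, Tuple


class PatchTransform(NamedTuple):
    """Affine coefficients: x = a*col + b*row + c, y = d*col + e*row + f."""

    a: float
    b: float
    c: float
    d: float
    e: float
    f: float


def transform_from_bbox(minx: float, maxy: float, res: float) -> PatchTransform:
    """North-up transform whose origin is the patch's top-left corner."""
    # No rotation terms; e is negative because rows increase southwards.
    return PatchTransform(res, 0.0, minx, 0.0, -res, maxy)


def pixel_to_map(t: PatchTransform, col: float, row: float) -> Tuple[float, float]:
    """Map a patch pixel coordinate to CRS coordinates."""
    x = t.a * col + t.b * row + t.c
    y = t.d * col + t.e * row + t.f
    return x, y
```

The resulting CRS coordinates could then be reprojected to WGS-84, for example with pyproj's `Transformer.from_crs(src_crs, "EPSG:4326", always_xy=True)`.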

@adamjstewart
Collaborator

Let's continue this discussion in #1407. @nilsleh, should we close this issue or keep it open until we have better documentation/error messages?

@nilsleh
Collaborator Author

nilsleh commented Jun 9, 2023

We can close this.
