Problem training with standard dataloader. #1426

Open
lcoandrade opened this issue Jun 19, 2023 · 5 comments
Labels
trainers PyTorch Lightning trainers

Comments

@lcoandrade

lcoandrade commented Jun 19, 2023

Description

I've just learnt about TorchGeo and wanted to try it, so I created a Kaggle notebook to test it with NAIP and Chesapeake data (Torchgeo 101).
When I try to train a segmentation task, I get the following error:
ValueError: A frozen dataclass was passed to `apply_to_collection` but this is not allowed.
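The failure mode behind this message can be shown without TorchGeo or Lightning. The sketch below is a stdlib-only stand-in: `FrozenBox` is a hypothetical class (not TorchGeo's `BoundingBox`), and `rebuild_fields` only imitates a copy-then-reassign pattern that a utility like `apply_to_collection` might use to rewrite a dataclass's fields. It is an assumption about the mechanism, not the library's actual code.

```python
import copy
import dataclasses


@dataclasses.dataclass(frozen=True)
class FrozenBox:
    """Hypothetical stand-in for a frozen dataclass carried in a sample."""

    minx: float
    maxx: float


def rebuild_fields(obj, func):
    """Copy a dataclass and reassign every field to func(value).

    Illustrative only: it mimics a copy-then-setattr approach and is not
    any library's actual implementation.
    """
    result = copy.copy(obj)
    for field in dataclasses.fields(obj):
        try:
            setattr(result, field.name, func(getattr(obj, field.name)))
        except dataclasses.FrozenInstanceError as err:
            # Frozen dataclasses reject attribute assignment after creation.
            raise ValueError("frozen dataclass cannot be rewritten in place") from err
    return result


def try_rebuild(obj):
    """Return the error message if rebuilding fails, else None."""
    try:
        rebuild_fields(obj, lambda value: value * 2)
    except ValueError as err:
        return str(err)
    return None


message = try_rebuild(FrozenBox(0.0, 1.0))
print(message)
```

A non-frozen dataclass passes through `rebuild_fields` without error, which is why only batches that still contain frozen objects trigger the problem.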

Steps to reproduce

  1. Create a dataset with NAIP and Chesapeake data:
import os

from torchgeo.datasets import NAIP, ChesapeakeDE

# Create the NAIP dataset
naip_root = os.path.join(INPUT_DIR, "naip")
naip = NAIP(naip_root)

# Create the Chesapeake dataset, matched to NAIP's CRS and resolution
chesapeake_root = os.path.join(INPUT_DIR, "chesapeake")
chesapeake = ChesapeakeDE(
    chesapeake_root,
    crs=naip.crs,
    res=naip.res,
    download=False,
)
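  2. Make an intersection, create a sampler and a dataloader:
from torch.utils.data import DataLoader
from torchgeo.datasets import stack_samples
from torchgeo.samplers import RandomGeoSampler
dataset = naip & chesapeake
sampler = RandomGeoSampler(dataset, size=IMG_SIZE, length=SAMPLE_SIZE)
dataloader = DataLoader(dataset, sampler=sampler, collate_fn=stack_samples)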
  3. Define a logger, a checkpoint callback and a segmentation task:
# Disable HTTPS certificate verification for downloads made through urllib
ssl._create_default_https_context = ssl._create_unverified_context

test_dir = os.path.join(OUTPUT_DIR, "test")
os.makedirs(test_dir, exist_ok=True)

logger = CSVLogger(test_dir, name="torchgeo_logs")

checkpoint_callback = ModelCheckpoint(
    every_n_epochs=1,
    dirpath=test_dir,
    filename="torchgeo_trained",
)

task = SemanticSegmentationTask(
    model=SEGMENTATION_MODEL,
    backbone=BACKBONE,
    weights=WEIGHTS,
    in_channels=IN_CHANNELS,
    num_classes=NUM_CLASSES,
    loss=LOSS,
    ignore_index=None,
    learning_rate=LR,
    learning_rate_schedule_patience=PATIENCE,
)
  4. Define a trainer, using the logger and checkpoint callback from step 3:
if torch.cuda.is_available():
    DEVICE, NUM_DEVICES = "cuda", torch.cuda.device_count()
else:
    DEVICE, NUM_DEVICES = "cpu", mp.cpu_count()
WORKERS = mp.cpu_count()
print(f"Running on {NUM_DEVICES} {DEVICE}(s)")

trainer = pl.Trainer(
    accelerator=DEVICE,
    devices=NUM_DEVICES,
    max_epochs=EPOCHS,
    callbacks=[checkpoint_callback],
    logger=logger,
)
  5. Start training:
trainer.fit(model=task, train_dataloaders=dataloader)

Version

0.4.1

@adamjstewart
Collaborator

Duplicate of #1056 and #1418

The issue is that some of the sample values returned by GeoDataset (BoundingBox, CRS) can't be automatically collated by PyTorch. Our solution in our built-in data modules is to remove these values before loading: https://github.com/microsoft/torchgeo/blob/v0.4.1/torchgeo/datamodules/geo.py#L280

My suggestion would be to write a simple data module (there are dozens of built-in examples) and use that instead of directly using a data loader. Maybe this is something we could add to our collation functions...
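One way to act on the collation idea is to strip the non-tensor entries inside the collate function itself. The following is a minimal sketch, not TorchGeo's implementation: `drop_keys_collate` and `list_collate` are hypothetical helpers, and the key names `"crs"` and `"bbox"` are assumed from the value types named above. It wraps any base collate function and removes those keys from the collated batch.

```python
from typing import Any, Callable, Dict, Iterable, List

Sample = Dict[str, Any]


def drop_keys_collate(
    base_collate: Callable[[List[Sample]], Sample],
    keys: Iterable[str] = ("crs", "bbox"),  # assumed key names
) -> Callable[[List[Sample]], Sample]:
    """Wrap a collate function so the given keys are removed from each batch."""
    keys = tuple(keys)

    def collate(samples: List[Sample]) -> Sample:
        batch = base_collate(samples)
        # Build a new dict without the unwanted keys; missing keys are ignored.
        return {name: value for name, value in batch.items() if name not in keys}

    return collate


def list_collate(samples: List[Sample]) -> Sample:
    """Toy base collate for the demo: gather each key's values into a list."""
    return {name: [sample[name] for sample in samples] for name in samples[0]}


samples = [
    {"image": 1, "mask": 0, "crs": "EPSG:4326", "bbox": (0, 1)},
    {"image": 2, "mask": 1, "crs": "EPSG:4326", "bbox": (1, 2)},
]
batch = drop_keys_collate(list_collate)(samples)
print(batch)  # {'image': [1, 2], 'mask': [0, 1]}
```

With TorchGeo, the base function would presumably be `stack_samples`, passed as `collate_fn=drop_keys_collate(stack_samples)` to the `DataLoader`; that combination is untested here.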

@adamjstewart adamjstewart added the trainers PyTorch Lightning trainers label Jun 19, 2023
@adamjstewart
Collaborator

Is this still an issue or can this be closed?

@lcoandrade
Author

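I've made a CustomGeoDataModule like this:

import torch
from torchgeo.datamodules import GeoDataModule
from torchgeo.datasets import random_bbox_assignment
from torchgeo.samplers import GridGeoSampler, RandomBatchGeoSampler


class CustomGeoDataModule(GeoDataModule):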
    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
        self.dataset = self.dataset_class(**self.kwargs)
        
        generator = torch.Generator().manual_seed(0)
        (
            self.train_dataset,
            self.val_dataset,
            self.test_dataset,
        ) = random_bbox_assignment(self.dataset, [0.6, 0.2, 0.2], generator)
        
        if stage in ["fit"]:
            self.train_batch_sampler = RandomBatchGeoSampler(
                self.train_dataset, self.patch_size, self.batch_size, self.length
            )
        if stage in ["fit", "validate"]:
            self.val_sampler = GridGeoSampler(
                self.val_dataset, self.patch_size, self.patch_size
            )
        if stage in ["test"]:
            self.test_sampler = GridGeoSampler(
                self.test_dataset, self.patch_size, self.patch_size
            )

This solved my problem.
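The `[0.6, 0.2, 0.2]` fractions and the fixed seed give a reproducible train/validation/test split. As a rough illustration of that idea only, here is a stdlib sketch that splits plain items; `random_fraction_split` is a hypothetical helper and does not reproduce how `random_bbox_assignment` handles bounding boxes or rounding.

```python
import random
from typing import List, Sequence


def random_fraction_split(
    items: Sequence, fractions: Sequence[float], seed: int = 0
) -> List[list]:
    """Shuffle items with a fixed seed and cut them into len(fractions) groups."""
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError("fractions must sum to 1")
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)

    # Floor each group's size, then hand leftover items to the first groups.
    sizes = [int(len(shuffled) * fraction) for fraction in fractions]
    for i in range(len(shuffled) - sum(sizes)):
        sizes[i % len(sizes)] += 1

    groups, start = [], 0
    for size in sizes:
        groups.append(shuffled[start:start + size])
        start += size
    return groups


train, val, test = random_fraction_split(range(10), [0.6, 0.2, 0.2])
print(len(train), len(val), len(test))  # 6 2 2
```

Calling the function again with the same seed returns the same groups, which is the property the fixed `torch.Generator` seed provides in the data module.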

@trchudley

trchudley commented Nov 6, 2023

> My suggestion would be to write a simple data module (there are dozens of built-in examples) and use that instead of directly using a data loader. Maybe this is something we could add to our collation functions...

Hi @adamjstewart

I've also run into this problem, and it took me a while to find the solution. Definitely +1 for building this into torchgeo so that GeoDatasets work as seamlessly as possible for end users.

Cheers,
Tom

@adamjstewart adamjstewart reopened this Nov 6, 2023
@adamjstewart
Collaborator

Reopening as a reminder to try to upstream some of our changes to PyTorch.
