
InriaAerialImageLabelingDataModule: fix predict dimensions #975

Merged
adamjstewart merged 2 commits into main from datamodules/inria on Dec 26, 2022

Conversation

adamjstewart (Collaborator)

Previously, the following code would fail due to an extra dimension in predict samples:

from pytorch_lightning import Trainer
from torchgeo.datamodules import InriaAerialImageLabelingDataModule
from torchgeo.trainers import SemanticSegmentationTask

# Constructor arguments elided; any standard setup reproduces the failure.
datamodule = InriaAerialImageLabelingDataModule(...)
model = SemanticSegmentationTask(...)
trainer = Trainer(...)

trainer.fit(model=model, datamodule=datamodule)
trainer.predict(model=model, datamodule=datamodule)  # raised due to the extra patch dimension

This is now fixed and properly tested with a real trainer. I also removed the inria datamodule tests since they aren't useful.
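For illustration, a minimal sketch of what the fix means for predict batches, assuming the datamodule from the snippet above has already been instantiated (the assertion is mine, not part of the PR diff):

datamodule.setup("predict")
batch = next(iter(datamodule.predict_dataloader()))
# Predict images are now plain 4D batches (batch, channels, height, width),
# the same layout SemanticSegmentationTask sees during fit, with no extra patch dimension.
assert batch["image"].ndim == 4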

@adamjstewart adamjstewart added this to the 0.3.2 milestone Dec 24, 2022
@github-actions github-actions bot added datamodules PyTorch Lightning datamodules testing Continuous integration testing labels Dec 24, 2022
Comment on lines -83 to +87
trainer.predict(model=model, dataloaders=datamodule.val_dataloader())

if hasattr(datamodule, "predict_dataset"):
    trainer.predict(model=model, datamodule=datamodule)
adamjstewart (Collaborator, Author)

The only reason we run predict on the val set instead of the predict set is because not all datamodules have a predict set. In order to get 100% coverage, we ran everything on val instead. However, at least for segmentation, we do have a predict set, so we should use it. This is how I discovered the bug to begin with.
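For context, a rough sketch of the test pattern this reasoning implies; the fixture names and surrounding test function are assumptions for illustration, not the actual test code:

def test_trainer(model, datamodule, trainer) -> None:
    trainer.fit(model=model, datamodule=datamodule)
    # Prefer the dedicated predict set when the datamodule defines one;
    # otherwise fall back to the val loader so predict_step stays covered.
    if hasattr(datamodule, "predict_dataset"):
        trainer.predict(model=model, datamodule=datamodule)
    else:
        trainer.predict(model=model, dataloaders=datamodule.val_dataloader())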

@adamjstewart (Collaborator, Author)

@ashnair1 this likely breaks your code in #560. I can't think of an easy way to support both "b c h w" and "b n c h w" in forward, but if you can think of one let me know. Otherwise, you'll need to "b c h w -> b n c h w" manually before running CombineTensorPatches.

@ashnair1 (Collaborator)

> @ashnair1 this likely breaks your code in #560. I can't think of an easy way to support both "b c h w" and "b n c h w" in forward, but if you can think of one let me know. Otherwise, you'll need to "b c h w -> b n c h w" manually before running CombineTensorPatches.

Can't think of a way to support both. Adding the number of patches to the sample dict as suggested in my comment will allow reshaping prior to CombineTensorPatches. That should work for now.
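For anyone hitting this downstream, a minimal sketch of that reshape with einops, assuming patches from one tile are stacked along the batch axis and the patch count is carried in the sample dict (the num_patches argument and function name here are hypothetical):

from einops import rearrange

def restore_patch_dim(pred, num_patches):
    # predict now yields flat "b c h w" tensors; CombineTensorPatches expects a
    # separate patch dimension, so regroup them: (b*n, c, h, w) -> (b, n, c, h, w).
    return rearrange(pred, "(b n) c h w -> b n c h w", n=num_patches)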

@adamjstewart adamjstewart merged commit f2d3115 into main Dec 26, 2022
@adamjstewart adamjstewart deleted the datamodules/inria branch December 26, 2022 16:20
@adamjstewart adamjstewart modified the milestones: 0.3.2, 0.4.0 Jan 23, 2023
yichiac pushed a commit to yichiac/torchgeo that referenced this pull request Apr 29, 2023
…#975)

* InriaAerialImageLabelingDataModule: fix predict dimensions

* Record number of patches for reconstruction