LightningDataModule: GPU data augmentation support #10469

Closed
adamjstewart opened this issue Nov 10, 2021 · 4 comments
Labels
feature Is an improvement or enhancement

Comments

@adamjstewart
Contributor

🚀 Feature

Data augmentation libraries like Kornia support computation directly on the GPU, greatly increasing the rate at which batches of images can be processed. I would like to be able to perform these kinds of GPU transforms in a LightningDataModule.
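For context, here is a minimal sketch of the kind of transform I mean (the specific augmentations are just examples):

```python
import torch
import kornia.augmentation as K

# Kornia augmentations are nn.Modules, so they can be moved to the GPU
# and applied to entire batches of images at once
gpu_transforms = torch.nn.Sequential(
    K.RandomHorizontalFlip(p=0.5),
    K.RandomAffine(degrees=15, p=0.5),
).to("cuda")

batch = torch.rand(32, 3, 256, 256, device="cuda")  # batch already on the GPU
augmented = gpu_transforms(batch)
```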

Motivation

In TorchGeo, we use PyTorch Lightning to organize reproducible benchmarks for geospatial datasets. Currently, we have a LightningDataModule for each dataset and a much smaller number of LightningModules, one for each task (semantic segmentation, classification, regression, etc.). However, the LightningDataModule doesn't seem to know anything about the GPU, and the LightningModule doesn't seem to know anything about the LightningDataModule. Because of this, if we want to perform dataset-specific augmentations on the GPU, we're forced to create a separate LightningModule for each dataset, which increases code duplication and defeats the whole purpose of PyTorch Lightning.

Pitch

The purpose of a LightningDataModule is to handle all dataset-specific loading and augmentation so that a generic LightningModule can handle the actual training and evaluation. However, in order to take advantage of GPU-accelerated libraries like Kornia, we're currently forced to move this logic to a LightningModule. As datasets continue to grow in size, direct support for GPU-accelerated transforms in LightningDataModules will only increase in importance.

Alternatives

So far the only alternative we've found is to create a different LightningModule for each dataset and include the data augmentation there. If there's a better alternative to this we would love to know about it!

@calebrob6 @isaaccorley

@tchaton
Contributor

tchaton commented Nov 11, 2021

Hey @adamjstewart,

Awesome work on TorchGeo!

You can actually enable GPU transforms from your DataModule quite easily.

Check this DataModule in Lightning Flash: https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/core/data/new_data_module.py#L240.

The DataModule implements an on_after_batch_transfer hook, which is attached to the model during training and applied right after the batch has been transferred to the device.
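For instance, a minimal sketch of such a DataModule (the augmentations and names here are placeholders, not the Flash implementation):

```python
import kornia.augmentation as K
from pytorch_lightning import LightningDataModule

class AugmentedDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        # Batched augmentations that run on whatever device the batch is on
        self.aug = K.AugmentationSequential(
            K.RandomHorizontalFlip(p=0.5),
            K.RandomAffine(degrees=15, p=0.5),
            data_keys=["input"],
        )

    def on_after_batch_transfer(self, batch, dataloader_idx):
        # Called after the batch has been moved to the device, so the
        # augmentations execute on the GPU when one is used
        x, y = batch
        return self.aug(x), y
```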

In Flash, each dataset owns its transforms, and we extract the on_after_batch_transfer_fn directly from the InputTransform.

Have a look at our tutorial: https://github.com/PyTorchLightning/lightning-flash/blob/master/flash_examples/flash_components/custom_data_loading.py

Best,
T.C

@adamjstewart
Contributor Author

Thanks @tchaton, I knew there had to be some way to do it! We'll try this out and let you know if we hit any snags.

@adamjstewart
Contributor Author

@tchaton the docs for on_after_batch_transfer mention using self.trainer.training, but I can't see where that attribute is defined. Is it a boolean? Also, one of the links you shared earlier seems to be dead now.

@tchaton
Contributor

tchaton commented Dec 20, 2021

Hey @adamjstewart,

Yes, it is a boolean. It is defined on the Trainer here: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/trainer.py#L2028.
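So, assuming the DataModule sketch above (with the augmentations stored in self.aug), gating on it might look like:

```python
def on_after_batch_transfer(self, batch, dataloader_idx):
    x, y = batch
    # self.trainer is set once the DataModule is attached to a Trainer;
    # trainer.training is True only while the training loop is running
    if self.trainer is not None and self.trainer.training:
        x = self.aug(x)
    return x, y
```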

By the way, we entirely refactored the Lightning Flash Data API, and I believe you might want to have a look at it to better organize your own library.

Best,
T.C
