LightningDataModule: GPU data augmentation support #10469
Comments
Hey @adamjstewart, awesome work on TorchGeo! You can actually enable GPU transforms from your DataModule quite easily. Check this DataModule in Lightning Flash: https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/core/data/new_data_module.py#L240. In Flash, the datasets own their transforms, and the DataModule extracts them and applies them in the appropriate hooks. Have a look at our tutorial: https://github.com/PyTorchLightning/lightning-flash/blob/master/flash_examples/flash_components/custom_data_loading.py. Best,
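For reference, here is a minimal sketch of that pattern (not the Flash or TorchGeo code itself): a LightningDataModule that applies Kornia augmentations in its `on_after_batch_transfer` hook, so the transforms run on whatever device the batch was just moved to. The class name, transforms, and batch structure are illustrative assumptions.

```python
import kornia.augmentation as K
import pytorch_lightning as pl
from torch import nn


class GPUAugDataModule(pl.LightningDataModule):
    """Hypothetical DataModule that runs Kornia augmentations after device transfer."""

    def __init__(self, batch_size: int = 32):
        super().__init__()
        self.batch_size = batch_size
        # Kornia augmentations are nn.Modules that operate on batched tensors,
        # so they run on whatever device the batch already lives on.
        self.augment = nn.Sequential(
            K.RandomHorizontalFlip(p=0.5),
            K.RandomVerticalFlip(p=0.5),
        )

    def on_after_batch_transfer(self, batch, dataloader_idx):
        # By this point Lightning has already moved ``batch`` to the target device.
        images, labels = batch
        return self.augment(images), labels
```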
Thanks @tchaton, I knew there had to be some way to do it! We'll try this out and let you know if we hit any snags.
@tchaton the docs for on_after_batch_transfer mention using `self.trainer.training` to check the current stage. Is that just a boolean attribute on the Trainer?
Hey @adamjstewart, yes, it is a boolean. It is defined on the Trainer here: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/trainer.py#L2028. By the way, we have entirely refactored the Lightning Flash Data API, and I believe you might want to take a look; it could help you better organize your own library. Best,
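To make that concrete, here is a hedged sketch of using that boolean: gating the augmentations on `self.trainer.training` so random transforms only run during the training stage. The `augment` attribute is the hypothetical one from the sketch above, and this assumes the DataModule has been attached to a Trainer.

```python
    def on_after_batch_transfer(self, batch, dataloader_idx):
        images, labels = batch
        # ``trainer.training`` / ``trainer.validating`` / ``trainer.testing`` are
        # boolean stage flags on the Trainer; only augment during training.
        if self.trainer is not None and self.trainer.training:
            images = self.augment(images)
        return images, labels
```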
🚀 Feature
Data augmentation libraries like Kornia support computation directly on the GPU, greatly increasing the rate at which images can be sampled and augmented. I would like to be able to perform these kinds of GPU transforms in a LightningDataModule.
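For example (illustrative only; the tensor shape and transform choice are arbitrary), Kornia augmentations are plain `nn.Module`s that operate on batched tensors, so they execute on whichever device the batch is on:

```python
import torch
import kornia.augmentation as K

augment = K.RandomHorizontalFlip(p=0.5)
images = torch.rand(8, 3, 256, 256, device="cuda")  # a batch already on the GPU
augmented = augment(images)  # the transform runs on the GPU, not the CPU
```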
Motivation
In TorchGeo, we use PyTorch Lightning to organize reproducible benchmarks for geospatial datasets. Currently, we have a LightningDataModule for each dataset and a much smaller number of LightningModules, one for each task (semantic segmentation, classification, regression, etc.). However, the LightningDataModule doesn't seem to know anything about the GPU, and the LightningModule doesn't seem to know anything about the LightningDataModule. Because of this, if we want to perform dataset-specific augmentations on the GPU, we're forced to create a separate LightningModule for each dataset, which increases code duplication and defeats the whole purpose of PyTorch Lightning.
Pitch
The purpose of a LightningDataModule is to handle all dataset-specific loading and augmentation so that a generic LightningModule can handle the actual training and evaluation. However, in order to take advantage of GPU-accelerated libraries like Kornia, we're currently forced to move this logic to a LightningModule. As datasets continue to increase in size, direct support for GPU-accelerated transforms in LightningDataModules will increase in importance.
Alternatives
So far the only alternative we've found is to create a different LightningModule for each dataset and include the data augmentation there. If there's a better alternative to this we would love to know about it!
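As a rough illustration of that workaround (hypothetical class names, not TorchGeo's actual code), each dataset ends up with its own LightningModule subclass whose only purpose is to hard-code that dataset's GPU augmentations:

```python
import kornia.augmentation as K
import pytorch_lightning as pl
from torch import nn


class SemanticSegmentationTask(pl.LightningModule):
    """Generic task logic (model, losses, optimizers) would live here."""
    ...


class LandcoverSegmentationTask(SemanticSegmentationTask):
    """One such subclass per dataset, just to bake in dataset-specific GPU transforms."""

    def __init__(self):
        super().__init__()
        self.augment = nn.Sequential(K.RandomHorizontalFlip(p=0.5))

    def on_after_batch_transfer(self, batch, dataloader_idx):
        images, masks = batch
        if self.trainer is not None and self.trainer.training:
            images = self.augment(images)
        return images, masks
```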
@calebrob6 @isaaccorley