How to access LightningDataModule in LightningModule
#10492
-
In TorchGeo, we use PyTorch Lightning to organize reproducible benchmarks for geospatial datasets. Currently, we have a set of LightningDataModules for each dataset and a much smaller number of LightningModules, one per task (semantic segmentation, classification, regression, etc.). Each Dataset defines its own [...]

During training/validation steps, we would like to plot a few examples to see how training is progressing. However, the LightningModule doesn't seem to know anything about the LightningDataModule/DataLoader/Dataset. Because of this, if we want to perform dataset-specific plotting during training or validation steps, we're forced to create a separate LightningModule for each dataset, which increases code duplication and defeats the whole purpose of PyTorch Lightning (example).

Is there an easy way for a LightningModule to tell which DataModule/DataLoader/Dataset is being used and call its [...]

@tchaton this is slightly related to #10469, but different enough that I wanted to start a separate discussion about it.
Replies: 2 comments 3 replies
-
@adamjstewart There is a reference to the datamodule via the trainer from the LightningModule: self.trainer.datamodule. Would that solve your issue?
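For illustration, here is a minimal sketch of that access pattern. It assumes the datamodule exposes a `plot` method, which is a hypothetical name used only for this example and not part of the Lightning API. The stand-in objects at the bottom exist only so the snippet runs without Lightning installed; inside a real LightningModule method, `pl_module` would simply be `self`.

```python
from types import SimpleNamespace


def plot_with_datamodule(pl_module, batch):
    """Delegate plotting to whichever datamodule the trainer was given."""
    # The reference suggested above: module -> trainer -> datamodule.
    datamodule = pl_module.trainer.datamodule
    # `plot` is a hypothetical, dataset-specific method (an assumption of this sketch).
    return datamodule.plot(batch)


# Stand-ins that mimic the attribute chain, so the sketch is runnable on its own.
fake_datamodule = SimpleNamespace(plot=lambda batch: f"plotted {len(batch)} samples")
fake_module = SimpleNamespace(trainer=SimpleNamespace(datamodule=fake_datamodule))

result = plot_with_datamodule(fake_module, [1, 2, 3])
print(result)
```

With this shape, the task-level LightningModule stays dataset-agnostic and the dataset-specific logic lives with the data.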
-
This dependence suggests that the data isn't fully separable from the model/loop. Relying on self.trainer.datamodule is not foolproof: someone could use your LightningModule but pass the dataloaders directly to trainer.fit. In that case no datamodule is provided, and the module could fail unless it checks for this.
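A defensive version of the lookup might look like the sketch below. It assumes that `trainer.datamodule` is `None` when dataloaders are passed directly to `trainer.fit`, and `plot` is again a hypothetical method name. The stand-in objects replace real Lightning classes so the snippet runs on its own.

```python
from types import SimpleNamespace


def maybe_plot(pl_module, batch):
    """Plot via the datamodule if one is attached; otherwise return None."""
    datamodule = pl_module.trainer.datamodule
    # No datamodule (assumed None when only dataloaders were passed),
    # or a datamodule without the hypothetical `plot` method: skip plotting.
    if datamodule is None or not hasattr(datamodule, "plot"):
        return None
    return datamodule.plot(batch)


# Case 1: dataloaders were passed directly, so no datamodule is attached.
loaders_only = SimpleNamespace(trainer=SimpleNamespace(datamodule=None))
# Case 2: a datamodule with a plot method is attached.
with_datamodule = SimpleNamespace(
    trainer=SimpleNamespace(datamodule=SimpleNamespace(plot=lambda batch: "figure"))
)

skipped = maybe_plot(loaders_only, [0])
plotted = maybe_plot(with_datamodule, [0])
```

Returning `None` keeps training running when plotting is unavailable; raising a descriptive error instead would be an equally reasonable choice if plotting is mandatory.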