Skip to content

Commit

Permalink
as discussed, datamodule and dataloaders (from model) are merged. Datamodule has the precedence
Browse files Browse the repository at this point in the history
  • Loading branch information
gianscarpe committed Feb 21, 2021
1 parent f96d871 commit de60a3b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
10 changes: 6 additions & 4 deletions pytorch_lightning/trainer/configuration_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@ def verify_loop_configurations(self, model: LightningModule):
self.__verify_eval_loop_configuration(model, 'test')

def __verify_train_dataloader(self, model):
# We COMBINE dataloaders and datamodule. Datamodule has precedence

has_train_dataloader = False
if self.trainer.datamodule:
has_train_dataloader = is_overridden(
'train_dataloader', self.trainer.datamodule
)
else:
has_train_dataloader = is_overridden('train_dataloader', self.trainer.datamodule)

if not has_train_dataloader:
has_train_dataloader = is_overridden('train_dataloader', model)

if not has_train_dataloader:
Expand Down
15 changes: 14 additions & 1 deletion pytorch_lightning/trainer/data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.core import LightningModule
from pytorch_lightning.trainer.supporters import CombinedLoader
Expand Down Expand Up @@ -196,8 +197,20 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
model: The current `LightningModule`
"""

# TODO - what if we pass `train_dataloader` as input, instead of model?
# TODO - what if we pass `train_dataloader` as input, in
model_has_loader = is_overridden('train_dataloader', model)
datamodule_has_loader = False

if self.datamodule:
datamodule_has_loader = is_overridden('train_dataloader', model)

if datamodule_has_loader and model_has_loader:
log.info(
"You implemented `train_dataloader` both in model and datamodule. Please note that datamodule implementation has precedence"
)

# Configuration_validation has already checked that either datamodule or model has overridden train_dataloader
if datamodule_has_loader:
train_dataloader = self.datamodule.train_dataloader
else:
train_dataloader = model.train_dataloader
Expand Down

0 comments on commit de60a3b

Please sign in to comment.