
Avoid patching LightningModule methods during training #6030

Closed
awaelchli opened this issue Feb 17, 2021 · 15 comments · Fixed by #9764
@awaelchli
Contributor

awaelchli commented Feb 17, 2021

🚀 Feature

Can we implement the dataloaders without 🐒-patching the methods in LightningModule?

Motivation

Currently, the trainer patches the LightningModule's dataloader methods when a DataModule is also used.
https://github.com/PyTorchLightning/pytorch-lightning/blob/5157ba55095a6a9f93ec1976aac877c87b00158f/pytorch_lightning/trainer/connectors/data_connector.py#L115

A datamodule's dataloader methods take precedence over the ones defined in the LightningModule, but the LightningModule itself should not be altered. The user does not know that this happens, and after training is complete, they may wish to continue using the model instance.
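
For context, the attachment amounts to roughly the following (a simplified sketch, not the actual source):

# Simplified sketch of what "attaching" a datamodule currently does: the
# datamodule's hooks are written over the LightningModule's methods, so the
# user's model object is silently mutated.
def attach_datamodule_by_patching(model, datamodule):
    if datamodule is None:
        return
    for name in ("train_dataloader", "val_dataloader", "test_dataloader"):
        hook = getattr(datamodule, name, None)
        if hook is not None:
            setattr(model, name, hook)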

Pitch

Store the dataloader references in the trainer (or data connector) directly, without "attaching" them to the user's model.
This would also enable type inference, as mentioned by @gianscarpe.
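
A rough sketch of this direction, with hypothetical names (not the actual Lightning API): the trainer-side connector keeps its own references and the loops query it, so the model object is never touched.

from typing import Callable, Iterable, Optional


class _DataConnector:
    """Hypothetical sketch: the trainer keeps the dataloader sources itself
    instead of writing them onto the user's LightningModule."""

    def __init__(self) -> None:
        self._train_dataloader_source: Optional[Callable[[], Iterable]] = None

    def attach(self, model, datamodule=None) -> None:
        # prefer the datamodule's hook when a datamodule is provided, otherwise
        # fall back to the hook on the model -- but never mutate `model`
        source = getattr(datamodule, "train_dataloader", None) if datamodule else None
        self._train_dataloader_source = source or model.train_dataloader

    def request_train_dataloader(self) -> Iterable:
        # the training loop asks the connector, not the (possibly patched) model
        return self._train_dataloader_source()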

Alternatives

Keep it as is, but users will not be happy.
It is also harder to debug the way it is right now.

@awaelchli awaelchli added feature Is an improvement or enhancement help wanted Open to be worked on refactor labels Feb 17, 2021
@gianscarpe
Contributor

Hi @awaelchli, I could work on this, as I already explored datamodules for an issue in my own project related to before_batch_transfer and mypy type checking :)

@awaelchli
Contributor Author

Aha!! Awesome, you are welcome to work on this. Ping me if you encounter any troubles along the way.
cc @PyTorchLightning/core-contributors :)

@gianscarpe
Contributor

Hi @awaelchli, I'm working on the issue and I opened a draft PR #6103. I have some questions:

  • If I pass a datamodule with only the train dataloader implemented, while the model has the val and test dataloaders implemented, do we "combine" them? From the tests it seems the answer should be "yes", since, for example, predict mode is tested using a ClassifierModel with train_dataloader implemented and a datamodule with only predict_dataloader implemented.
  • If I pass a datamodule with train_dataloader and the model has the same method implemented, which one should be considered? I believe the datamodule implementation.

@awaelchli
Contributor Author

If I pass a datamodule with only the train dataloader implemented, while the model has the val and test dataloaders implemented, do we "combine" them?

Yes, it looks like combining is the current behavior. Makes sense. Even better would be to log which ones are used.

If I pass a datamodule with train_dataloader and the model has the same method implemented, which one should be considered? I believe the datamodule implementation.

The datamodule and any dataloaders passed to fit have precedence over the methods defined in the model.
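
In other words, something along these lines (hypothetical helper, for illustration only):

def resolve_train_dataloader(fit_dataloaders, datamodule, model):
    """Hypothetical helper spelling out the precedence described above."""
    # 1. dataloaders passed directly to trainer.fit(...) win
    if fit_dataloaders is not None:
        return fit_dataloaders
    # 2. then the datamodule's hook, if a datamodule was passed and defines one
    if datamodule is not None and getattr(datamodule, "train_dataloader", None) is not None:
        return datamodule.train_dataloader()
    # 3. finally, the hook defined on the LightningModule itself
    return model.train_dataloader()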

@carmocca carmocca added this to the 1.3 milestone Feb 21, 2021
@carmocca carmocca modified the milestones: v1.3, v1.4 Apr 26, 2021
@ananthsub
Contributor

@awaelchli @carmocca Why do we need to bind the dataloader functions to the model at all? Why can't we set these as attributes of the trainer and then delete the corresponding attributes at the end of the call? Or just call the right one of model/datamodule/dataloader inside the training/evaluation/predict loop?

@awaelchli
Contributor Author

Why do we need to bind the dataloader functions to the model at all?

I don't know why. It was there from the start. I don't see a reason why it should be necessary.

Why can't we set these as attributes of the trainer?

We can, and that is exactly what we propose in this issue.

If @gianscarpe doesn't have time to do it, I can try to get this started in 1.4. It will be a relatively wide-reaching refactor.

@kaushikb11
Contributor

@awaelchli Adding you as an assignee as well! 🚀

@gianscarpe
Contributor

Hi @awaelchli, I had just forgotten about this PR. I started working on it; if you want, we can work together :)

@awaelchli
Contributor Author

Yes please, that would be awesome. Feel free to kick it off; I will be happy to help finish it, and I can also help with the failing tests.

@justusschock
Member

@justusschock do you think it is possible to avoid patching altogether? There are good reasons (#6030) not to do it from a user's perspective. While I believe your PR solves the major issue, there could still be problems when the user wants to call their dataloader method to produce a fresh dataloader, for example in a callback (nothing is stopping them from doing so).

Some work has started here #7522

Originally posted by @awaelchli in #8885 (comment)

@justusschock
Member

@awaelchli @justusschock I think we need a central point for where dataloaders come from. Patching the dataloader methods onto the model uses the model as the source of truth, but the side effects are visible to the end user. Rather, we could create an internal DataHolder that pools the objects that have the DataHooks available. This would also codify the priority/precedence across datamodules, the LightningModule, and dataloaders passed directly to the trainer.

It will also raise the question: what happens when both the datamodule and the lightning module have these hooks implemented? https://github.com/PyTorchLightning/pytorch-lightning/blob/e0605472306d6b95bf2616ab88f8c29f4498402e/pytorch_lightning/core/hooks.py#L455-L807

do we raise an exception? do we run only the datamodule's, if available? do we run both and if so, in what order?

cc @ninginthecloud as another data area we should explore

Originally posted by @ananthsub in #8885 (comment)

@justusschock
Member

@awaelchli I agree with @ananthsub that this is definitely possible and should be done!

Rather, we could create an internal DataHolder that pools the objects that have the DataHooks available. This would also codify the priority/precedence across datamodules, the LightningModule, and dataloaders passed directly to the trainer.

I like that idea. Should this be separate from the DataConnector?

It will also raise the question: what happens when both the datamodule and the lightning module have these hooks implemented?
do we raise an exception? do we run only the datamodule's, if available? do we run both and if so, in what order?

Personally, I would give priority to whatever was passed explicitly. So if the module has a loader implemented, but a loader or a datamodule with that loader is passed to the entrypoint, I would only use the latter. I wouldn't raise an exception, but a warning.

One point is that we don't know the precedence otherwise. In validation it would be fine to chain them, but in training it isn't that easy, and I think we should be consistent here between training and validation.

Another point is that this is also the current behaviour, i.e. by patching the model we already ignore the loaders from the model when a loader is given explicitly.
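
As a minimal sketch of that rule (a hypothetical DataHolder, not an agreed API), the explicitly passed datamodule wins and a warning is emitted when both define the hook:

import warnings


class DataHolder:
    """Hypothetical container pooling the objects that provide the DataHooks."""

    def __init__(self, model, datamodule=None):
        self.model = model
        self.datamodule = datamodule

    def get_hook(self, name):
        dm_hook = getattr(self.datamodule, name, None) if self.datamodule else None
        model_hook = getattr(self.model, name, None)
        if dm_hook is not None and model_hook is not None:
            # both implement the hook: warn, then prefer the explicitly passed datamodule
            warnings.warn(
                f"Both the datamodule and the model define `{name}`; "
                "using the datamodule implementation."
            )
        return dm_hook if dm_hook is not None else model_hook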

@ananthsub
Contributor

fyi @ninginthecloud

@zzzwen

zzzwen commented Sep 8, 2021

One more data point: this patching "solution" also makes test mocking hard.

Say I create a mock data module.

from unittest.mock import MagicMock
import pytorch_lightning as pl

mock_data_module = MagicMock(
    spec=pl.LightningDataModule,
    wraps=pl.LightningDataModule,
)

This line will fail:
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/connectors/data_connector.py#L225

because the parent's and the instance's method code is the same:
https://github.com/PyTorchLightning/pytorch-lightning/blob/a079d7fccc0a9be25b40296f2a348c4b4f40c8cf/pytorch_lightning/utilities/model_helpers.py#L70-L71
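
Roughly, the check compares code objects, and a mock that wraps the class hands back the class's own method, so the comparison always matches (a simplified, hypothetical recreation, not the actual source):

from unittest.mock import Mock


def looks_overridden(method_name, instance, parent):
    """Simplified, hypothetical recreation of an is_overridden-style check."""
    instance_attr = getattr(instance, method_name, None)
    # a Mock(wraps=cls) resolves child attributes to the wrapped class's own methods
    if isinstance(instance_attr, Mock) and instance_attr._mock_wraps is not None:
        instance_attr = instance_attr._mock_wraps
    parent_attr = getattr(parent, method_name, None)
    if instance_attr is None or parent_attr is None:
        return False
    # identical code objects mean the hook was never actually overridden
    return instance_attr.__code__ != parent_attr.__code__

With the mock above, train_dataloader unwraps to pl.LightningDataModule.train_dataloader, so both sides of the comparison point at the same code object and the hook is reported as not overridden.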

Therefore, train_dataloader will not be attached to the lightning_module, and it will fail the validator:

https://github.com/PyTorchLightning/pytorch-lightning/blob/a079d7fccc0a9be25b40296f2a348c4b4f40c8cf/pytorch_lightning/trainer/trainer.py#L938

@tchaton tchaton added the let's do it! approved to implement label Sep 10, 2021
@awaelchli
Contributor Author

@zzzwen I completely agree. If we had the possibility to mock these methods, it would simplify and harden a bunch of dataloader tests.
