Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove deprecation warnings being called for on_{task}_dataloader #9279

Merged
merged 5 commits into from
Sep 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed deprecated properties `DeepSpeedPlugin.cpu_offload*` in favor of `offload_optimizer`, `offload_parameters` and `pin_memory` ([#9244](https://github.com/PyTorchLightning/pytorch-lightning/pull/9244))


- Removed deprecation warnings being called for `on_{task}_dataloader` ([#9279](https://github.com/PyTorchLightning/pytorch-lightning/pull/9279))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was added in #9098 so the changelog shouldn't mention it again here. it's not released



### Fixed

- Fixed save/load/resume from checkpoint for DeepSpeed Plugin (
Expand Down
17 changes: 0 additions & 17 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from pytorch_lightning.utilities import move_data_to_device
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS
from pytorch_lightning.utilities.warnings import rank_zero_deprecation


class ModelHooks:
Expand Down Expand Up @@ -691,10 +690,6 @@ def on_train_dataloader(self) -> None:
:meth:`on_train_dataloader` is deprecated and will be removed in v1.7.0.
Please use :meth:`train_dataloader()` directly.
"""
rank_zero_deprecation(
"Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
" Please use `train_dataloader()` directly."
)

def on_val_dataloader(self) -> None:
"""Called before requesting the val dataloader.
Expand All @@ -703,10 +698,6 @@ def on_val_dataloader(self) -> None:
:meth:`on_val_dataloader` is deprecated and will be removed in v1.7.0.
Please use :meth:`val_dataloader()` directly.
"""
rank_zero_deprecation(
"Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
" Please use `val_dataloader()` directly."
)

def on_test_dataloader(self) -> None:
"""Called before requesting the test dataloader.
Expand All @@ -715,10 +706,6 @@ def on_test_dataloader(self) -> None:
:meth:`on_test_dataloader` is deprecated and will be removed in v1.7.0.
Please use :meth:`test_dataloader()` directly.
"""
rank_zero_deprecation(
"Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
" Please use `test_dataloader()` directly."
)

def on_predict_dataloader(self) -> None:
"""Called before requesting the predict dataloader.
Expand All @@ -727,10 +714,6 @@ def on_predict_dataloader(self) -> None:
:meth:`on_predict_dataloader` is deprecated and will be removed in v1.7.0.
Please use :meth:`predict_dataloader()` directly.
"""
rank_zero_deprecation(
"Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
" Please use `predict_dataloader()` directly."
)

def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
"""
Expand Down
37 changes: 31 additions & 6 deletions tests/deprecated_api/test_remove_1-7.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,25 +91,50 @@ def test_v1_7_0_trainer_prepare_data_per_node(tmpdir):
_ = Trainer(prepare_data_per_node=False)


def test_v1_7_0_deprecated_on_train_dataloader(tmpdir):
def test_v1_7_0_deprecated_on_task_dataloader(tmpdir):
class CustomBoringModel(BoringModel):
def on_train_dataloader(self):
print("on_train_dataloader")

def on_val_dataloader(self):
print("on_val_dataloader")

def on_test_dataloader(self):
print("on_test_dataloader")

def on_predict_dataloader(self):
print("on_predict_dataloader")

def _run(model, task="fit"):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
getattr(trainer, task)(model)

model = CustomBoringModel()

model = BoringModel()
with pytest.deprecated_call(
match="Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
):
model.on_train_dataloader()
_run(model, "fit")

with pytest.deprecated_call(
match="Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
):
model.on_val_dataloader()
_run(model, "fit")

with pytest.deprecated_call(
match="Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
):
_run(model, "validate")

with pytest.deprecated_call(
match="Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
):
model.on_test_dataloader()
_run(model, "test")

with pytest.deprecated_call(
match="Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
):
model.on_predict_dataloader()
_run(model, "predict")


@mock.patch("pytorch_lightning.loggers.test_tube.Experiment")
Expand Down
27 changes: 0 additions & 27 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1897,30 +1897,3 @@ def test_error_handling_all_stages(tmpdir, accelerator, num_processes):
) as exception_hook:
trainer.predict(model, model.val_dataloader(), return_predictions=False)
exception_hook.assert_called()


def test_overridden_on_dataloaders(tmpdir):
model = BoringModel()
with pytest.deprecated_call(
match="Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.fit(model)

with pytest.deprecated_call(
match="Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.validate(model)

with pytest.deprecated_call(
match="Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.test(model)

with pytest.deprecated_call(
match="Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.predict(model)