Revert part of #10279 (#10376)

Merged · 6 commits · Nov 8, 2021
17 changes: 7 additions & 10 deletions pytorch_lightning/lite/lite.py
@@ -238,18 +238,15 @@ def _setup_dataloader(
         )
         sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs)

-        dataloader_kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler)
-        try:
-            dataloader = type(dataloader)(**dataloader_kwargs)
-        except TypeError:
-            dataloader_kwargs.pop("dataset")
-            dataloader = type(dataloader)(**dataloader_kwargs)
+        # the dataloader needs to be re-instantiated because we want to update the input arguments (e.g., sampler)
+        dataloader = TrainerDataLoadingMixin._update_dataloader(dataloader, sampler)
+
         # add worker_init_fn for correct seeding in worker processes
         TrainerDataLoadingMixin._auto_add_worker_init_fn(dataloader, self.global_rank)
-        return _LiteDataLoader(
-            dataloader=self._strategy.process_dataloader(dataloader),
-            device=self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None,
-        )
+
+        dataloader = self._strategy.process_dataloader(dataloader)
+        device = self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None
+        return _LiteDataLoader(dataloader=dataloader, device=device)

     def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = None, **kwargs: Any) -> None:
         """Replaces ``loss.backward()`` in your training loop. Handles precision automatically for you.
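For reference, `TrainerDataLoadingMixin._update_dataloader` centralizes the re-instantiation that the removed `try`/`except` block did by hand: recover the arguments the dataloader was constructed with, swap in the new sampler, and rebuild the instance. A minimal sketch of that idea, assuming a hypothetical `update_dataloader` helper with a hard-coded attribute list (the real helper reconstructs the full `__init__` signature through introspection):

from torch.utils.data import DataLoader, Sampler, SequentialSampler


def update_dataloader(dataloader: DataLoader, sampler: Sampler) -> DataLoader:
    # Recover the constructor arguments from the instance's attributes,
    # swap in the new sampler, and rebuild the dataloader.
    kwargs = {
        "dataset": dataloader.dataset,
        "batch_size": dataloader.batch_size,
        "num_workers": dataloader.num_workers,
        "collate_fn": dataloader.collate_fn,
        "pin_memory": dataloader.pin_memory,
        "drop_last": dataloader.drop_last,
        "sampler": sampler,  # the injected (e.g. distributed) sampler
    }
    return type(dataloader)(**kwargs)


loader = DataLoader(range(10), batch_size=2)
new_loader = update_dataloader(loader, SequentialSampler(loader.dataset))
assert isinstance(new_loader.sampler, SequentialSampler)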
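The `worker_init_fn` attached by `_auto_add_worker_init_fn` keeps random augmentations reproducible across dataloader worker processes. A generic sketch of what such a function does, assuming the usual NumPy and stdlib generators (Lightning's actual `pl_worker_init_function` also mixes in the global rank):

import random

import numpy as np
import torch


def worker_init_fn(worker_id: int) -> None:
    # torch already gives each worker a distinct seed (base_seed + worker_id);
    # propagate it to the other generators used inside the worker.
    seed = torch.initial_seed() % 2**32
    np.random.seed(seed)
    random.seed(seed)

Such a function would be passed as `DataLoader(..., worker_init_fn=worker_init_fn)`.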
80 changes: 28 additions & 52 deletions tests/lite/test_lite.py
@@ -24,7 +24,12 @@
 from torch.utils.data import DataLoader, DistributedSampler, Sampler

 from pytorch_lightning.lite import LightningLite
-from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
+from pytorch_lightning.lite.wrappers import (
+    _LiteDataLoader,
+    _LiteModule,
+    _LiteOptimizer,
+    _replace_dataloader_init_method,
+)
 from pytorch_lightning.plugins import DeepSpeedPlugin, PrecisionPlugin, TrainingTypePlugin
 from pytorch_lightning.utilities import DistributedType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -192,57 +197,6 @@ def run(self):
     LiteWithCustomDataLoader().run()

[Review thread on the removed test below]
Contributor: We could have a custom dataloader using only a dataset as a test.
Contributor Author: We already test that in tests/trainer/test_data_loading.py
Contributor: I meant, we should unittest the context manager to make sure the params are saved.
Contributor Author: @tchaton isn't that covered by the test that @awaelchli added in #10334?

-def test_setup_custom_dataloaders():
"""Test that the setup_dataloaders method returns the dataloaders wrapped as LiteDataLoader."""
lite = EmptyLite()

class CustomDataLoader(DataLoader):
def __init__(self, value: int = 2, *args, **kwargs):
self.value = value
super().__init__(range(value), *args, **kwargs)

dataloader = CustomDataLoader(2, batch_size=2)

# single dataloader
lite_dataloader = lite.setup_dataloaders(dataloader)
assert lite_dataloader._dataloader
assert lite_dataloader.value == 2
batch0 = next(iter(lite_dataloader))
assert torch.equal(batch0, torch.tensor([0, 1]))

class CustomDataLoader2(DataLoader):
def __init__(self, range, *args, **kwargs):
self.range = range
super().__init__(range, *args, **kwargs)

dataloader = CustomDataLoader2(range(2), batch_size=2)

# single dataloader
lite_dataloader = lite.setup_dataloaders(dataloader)
assert lite_dataloader._dataloader
batch0 = next(iter(lite_dataloader))
assert torch.equal(batch0, torch.tensor([0, 1]))

class CustomDataLoader(DataLoader):
def __init__(self, value: int, *args, **kwargs):
super().__init__(range(value), *args, **kwargs)

class LiteWithCustomDataLoader(LightningLite):
def run(self):
# This doesn't fail as the context manager would save all the arguments provided
# to the dataloaders.
dataloader = CustomDataLoader(2, batch_size=2)
self.setup_dataloaders(dataloader)

LiteWithCustomDataLoader().run()

with pytest.raises(
MisconfigurationException, match="Trying to inject `DistributedSampler` into the `CustomDataLoader` instance"
):
dataloader = CustomDataLoader(2, batch_size=2)
lite_dataloader = lite.setup_dataloaders(dataloader)


 def test_setup_dataloaders_twice_fails():
     """Test that calling setup_dataloaders with a dataloader that is already wrapped fails."""
     lite = EmptyLite()
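The `MisconfigurationException` asserted in the removed test comes from re-instantiation being impossible for subclasses whose required `__init__` parameters were never stored on the instance. A rough sketch of that detection, assuming signature inspection along these lines (illustrative, not Lightning's actual code):

import inspect

from torch.utils.data import DataLoader


class CustomDataLoader(DataLoader):
    def __init__(self, value: int, *args, **kwargs):
        # `value` is consumed here and never stored on the instance
        super().__init__(range(value), *args, **kwargs)


loader = CustomDataLoader(2, batch_size=2)

# A required parameter that was never saved as an attribute cannot be
# recovered, so the dataloader cannot be rebuilt with a new sampler.
params = inspect.signature(type(loader).__init__).parameters
missing = [
    name
    for name, p in params.items()
    if name != "self"
    and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
    and p.default is inspect.Parameter.empty
    and not hasattr(loader, name)
]
print(missing)  # ['value'] -> this is the case where Lightning raises the error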
@@ -490,3 +444,25 @@ def run(self):
             assert self.is_global_zero == (self.local_rank == 0)

     Lite(strategy=DeepSpeedPlugin(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu").run()


+def test_replace_dataloader_init_method():
+    """Test that the context manager enables saving the parameters passed to the DataLoader __init__ method."""
+
+    class CustomDataLoader(DataLoader):
+        def __init__(self, extra_argument: int, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+
+    dataloader = CustomDataLoader(extra_argument=1, dataset=range(1))
+    lite = EmptyLite()
+    with pytest.raises(MisconfigurationException, match="extra_argument"):
+        dataloader = lite.setup_dataloaders(dataloader)
+
+    with _replace_dataloader_init_method():
+        dataloader = CustomDataLoader(extra_argument=1, dataset=range(1))
+        assert dataloader.extra_argument == 1
+        dataloader = lite.setup_dataloaders(dataloader)
+
+        dataloader = CustomDataLoader(1, range(1))
+        assert dataloader.extra_argument == 1
+        dataloader = lite.setup_dataloaders(dataloader)
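For context, `_replace_dataloader_init_method` was introduced in #10334 to make such parameters recoverable. A minimal sketch of the idea, using a simplified stand-in that only patches direct `DataLoader` subclasses (the real implementation in `pytorch_lightning.lite.wrappers` differs in detail):

import inspect
from contextlib import contextmanager

from torch.utils.data import DataLoader


@contextmanager
def replace_dataloader_init_method():
    # Patch each existing DataLoader subclass so the arguments passed to its
    # __init__ are recorded as attributes before the original __init__ runs.
    originals = {cls: cls.__init__ for cls in DataLoader.__subclasses__()}

    def make_wrapper(old_init):
        sig = inspect.signature(old_init)

        def wrapper(self, *args, **kwargs):
            bound = sig.bind(self, *args, **kwargs)
            for name, value in bound.arguments.items():
                param = sig.parameters[name]
                if name != "self" and param.kind not in (
                    param.VAR_POSITIONAL,
                    param.VAR_KEYWORD,
                ):
                    setattr(self, name, value)
            old_init(self, *args, **kwargs)

        return wrapper

    for cls, old_init in originals.items():
        cls.__init__ = make_wrapper(old_init)
    try:
        yield
    finally:
        # restore the unpatched constructors on exit
        for cls, old_init in originals.items():
            cls.__init__ = old_init


class CustomDataLoader(DataLoader):
    def __init__(self, extra_argument: int, *args, **kwargs):
        super().__init__(*args, **kwargs)


with replace_dataloader_init_method():
    loader = CustomDataLoader(extra_argument=1, dataset=range(1))
assert loader.extra_argument == 1  # the argument is now recoverable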