From 8bf7f9cce744bbc05ebce03bfd0567e3bc414128 Mon Sep 17 00:00:00 2001
From: four4fish <88516121+four4fish@users.noreply.github.com>
Date: Mon, 29 Nov 2021 12:11:21 -0800
Subject: [PATCH] 1/n Move Accelerator into strategy - move batch_to_device to strategy (#10649)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* 1/n Integrate Device Specific Accelerator Logic with strategy - move batch_to_device to strategy

* add changelog

* add model is not none check

* Apply suggestions from code review

Co-authored-by: thomas chaton
Co-authored-by: Carlos Mocholí

* Update CHANGELOG.md

* Update test_datamodules.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test_hooks.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update dp.py

Co-authored-by: thomas chaton
Co-authored-by: Carlos Mocholí
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 CHANGELOG.md                                  |  5 +----
 pytorch_lightning/accelerators/accelerator.py | 20 +----------------
 .../loops/epoch/evaluation_epoch_loop.py      |  2 +-
 .../loops/epoch/prediction_epoch_loop.py      |  2 +-
 .../loops/epoch/training_epoch_loop.py        |  2 +-
 pytorch_lightning/plugins/training_type/dp.py | 16 ++++++++++++--
 .../training_type/training_type_plugin.py     | 18 +++++++++++++++
 .../trainer/connectors/data_connector.py      |  2 +-
 tests/core/test_datamodules.py                |  7 ++++--
 tests/models/test_gpu.py                      | 22 +++++++++----------
 tests/models/test_hooks.py                    |  7 ++++--
 11 files changed, 59 insertions(+), 44 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 733c77957cbff..c2b50c1d41f5b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -68,10 +68,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Raised an error if the `batch_size` cannot be inferred from the current batch if it contained a string or was a custom batch object ([#10541](https://github.com/PyTorchLightning/pytorch-lightning/pull/10541))
 
--
-
-
--
+- Moved `batch_to_device` method from `Accelerator` to `TrainingTypePlugin` ([#10649](https://github.com/PyTorchLightning/pytorch-lightning/pull/10649))
 
 -
 
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index eb3886b209503..8ccc2d86edd9e 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -23,7 +23,7 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
-from pytorch_lightning.plugins.training_type import DataParallelPlugin, TrainingTypePlugin
+from pytorch_lightning.plugins.training_type import TrainingTypePlugin
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import rank_zero_deprecation
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
@@ -145,24 +145,6 @@ def teardown(self) -> None:
         """
         self.training_type_plugin.teardown()
 
-    def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
-        """Moves the batch to the correct device. The returned batch is of the same type as the input batch, just
-        having all tensors on the correct device.
-
-        Args:
-            batch: The batch of samples to move to the correct device
-            device: The target device
-            dataloader_idx: The index of the dataloader to which the batch belongs.
-        """
-        model = self.lightning_module
-        device = device or self.root_device
-
-        if model is not None and not isinstance(self.training_type_plugin, DataParallelPlugin):
-            # no need to transfer batch to device in DP mode
-            return model._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx)
-
-        return move_data_to_device(batch, device)
-
     def training_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
         """The actual training step.
 
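Note: the hunks below migrate each call site mechanically from the accelerator to the training type plugin. A minimal sketch of that call-site change, for orientation only (illustrative, not part of the patch):

    # previously, device transfer went through the Accelerator
    batch = trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx)

    # with this patch, the TrainingTypePlugin owns device placement
    batch = trainer.training_type_plugin.batch_to_device(batch, dataloader_idx=dataloader_idx)
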
diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
index cbaac51ff1d58..971796154fa77 100644
--- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -110,7 +110,7 @@ def advance(
 
         if not self.trainer._data_connector.evaluation_data_fetcher.store_on_device:
             with self.trainer.profiler.profile("evaluation_batch_to_device"):
-                batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx)
+                batch = self.trainer.training_type_plugin.batch_to_device(batch, dataloader_idx=dataloader_idx)
 
         self.batch_progress.increment_ready()
 
diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
index 58e65233dfe81..558b1052c4e50 100644
--- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -92,7 +92,7 @@ def advance(
             raise StopIteration
 
         with self.trainer.profiler.profile("predict_batch_to_device"):
-            batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx)
+            batch = self.trainer.training_type_plugin.batch_to_device(batch, dataloader_idx=dataloader_idx)
 
         self.batch_progress.increment_ready()
 
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 2a471ab198d1d..d150c8b374dad 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -156,7 +156,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
 
         if not self.trainer._data_connector.train_data_fetcher.store_on_device:
             with self.trainer.profiler.profile("training_batch_to_device"):
-                batch = self.trainer.accelerator.batch_to_device(batch)
+                batch = self.trainer.training_type_plugin.batch_to_device(batch)
 
         self.batch_progress.increment_ready()
 
diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py
index 3f1b9a3acfa50..b02b4bdefaa1b 100644
--- a/pytorch_lightning/plugins/training_type/dp.py
+++ b/pytorch_lightning/plugins/training_type/dp.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional
+from typing import Any, List, Optional
 
 import torch
 from torch.nn import DataParallel, Module
@@ -20,7 +20,7 @@
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
-from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.enums import _StrategyType
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import _METRIC_COLLECTION
@@ -66,6 +66,18 @@ def setup(self) -> None:
         self.model_to_device()
         self._model = self._setup_model(LightningParallelModule(self._model))
 
+    def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
+        """Moves the batch to the correct device.
+
+        The input and the output is the same type.
+
+        Args:
+            batch: The batch of samples to move to the correct device
+            device: The target device
+            dataloader_idx: The index of the dataloader to which the batch belongs.
+        """
+        return move_data_to_device(batch, device=device or self.root_device)
+
     def _setup_model(self, model: Module) -> DataParallel:
         """Wraps the given model into a :class:`~torch.nn.parallel.DataParallel` module."""
         return DataParallel(module=model, device_ids=self.parallel_devices)
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index b8244b9c2e165..75ae5592a29ef 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -26,6 +26,7 @@
 from pytorch_lightning.plugins import TorchCheckpointIO
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
+from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT
 
@@ -90,6 +91,23 @@ def _setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
         # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324
         return optimizer
 
+    def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
+        """Moves the batch to the correct device.
+
+        The returned batch is of the same type as the input batch, just
+        having all tensors on the correct device.
+
+        Args:
+            batch: The batch of samples to move to the correct device
+            device: The target device
+            dataloader_idx: The index of the dataloader to which the batch belongs.
+        """
+        model = self.lightning_module
+        device = device or self.root_device
+        if model is not None:
+            return model._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx)
+        return move_data_to_device(batch, device)
+
     @property
     @abstractmethod
     def on_gpu(self) -> bool:
diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py
index e6f76e0403bd7..d32fbc52dc4fb 100644
--- a/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/pytorch_lightning/trainer/connectors/data_connector.py
@@ -118,7 +118,7 @@ def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int = 0)
         data_fetcher.setup(
             dataloader,
             stage=stage,
-            batch_to_device=partial(self.trainer.accelerator.batch_to_device, dataloader_idx=dataloader_idx),
+            batch_to_device=partial(self.trainer.training_type_plugin.batch_to_device, dataloader_idx=dataloader_idx),
             profiler=self.trainer.profiler,
         )
         setattr(self, f"{stage}_data_fetcher", data_fetcher)
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index d35941ac2cb15..59b68a723edc1 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -253,7 +253,10 @@ def test_full_loop(tmpdir):
 
 
 @RunIf(min_gpus=1)
-@mock.patch("pytorch_lightning.accelerators.accelerator.Accelerator.lightning_module", new_callable=PropertyMock)
+@mock.patch(
+    "pytorch_lightning.plugins.training_type.training_type_plugin.TrainingTypePlugin.lightning_module",
+    new_callable=PropertyMock,
+)
 def test_dm_apply_batch_transfer_handler(get_module_mock):
     expected_device = torch.device("cuda", 0)
 
@@ -306,7 +309,7 @@ def transfer_batch_to_device(self, batch, device, dataloader_idx):
     model.transfer_batch_to_device = dm.transfer_batch_to_device
     model.on_after_batch_transfer = dm.on_after_batch_transfer
 
-    batch_gpu = trainer.accelerator.batch_to_device(batch, expected_device)
+    batch_gpu = trainer.training_type_plugin.batch_to_device(batch, expected_device)
 
     assert dm.on_before_batch_transfer_hook_rank == 0
     assert dm.transfer_batch_to_device_hook_rank == 1
diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py
index bf9dab47a71aa..9e0e67200c38f 100644
--- a/tests/models/test_gpu.py
+++ b/tests/models/test_gpu.py
@@ -252,35 +252,35 @@ def test_single_gpu_batch_parse():
     # non-transferrable types
     primitive_objects = [None, {}, [], 1.0, "x", [None, 2], {"x": (1, 2), "y": None}]
     for batch in primitive_objects:
-        data = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0"))
+        data = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0"))
         assert data == batch
 
     # batch is just a tensor
     batch = torch.rand(2, 3)
-    batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0"))
+    batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0"))
     assert batch.device.index == 0 and batch.type() == "torch.cuda.FloatTensor"
 
     # tensor list
     batch = [torch.rand(2, 3), torch.rand(2, 3)]
-    batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0"))
+    batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0"))
     assert batch[0].device.index == 0 and batch[0].type() == "torch.cuda.FloatTensor"
     assert batch[1].device.index == 0 and batch[1].type() == "torch.cuda.FloatTensor"
 
     # tensor list of lists
     batch = [[torch.rand(2, 3), torch.rand(2, 3)]]
-    batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0"))
+    batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0"))
     assert batch[0][0].device.index == 0 and batch[0][0].type() == "torch.cuda.FloatTensor"
     assert batch[0][1].device.index == 0 and batch[0][1].type() == "torch.cuda.FloatTensor"
 
     # tensor dict
     batch = [{"a": torch.rand(2, 3), "b": torch.rand(2, 3)}]
-    batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0"))
+    batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0"))
     assert batch[0]["a"].device.index == 0 and batch[0]["a"].type() == "torch.cuda.FloatTensor"
     assert batch[0]["b"].device.index == 0 and batch[0]["b"].type() == "torch.cuda.FloatTensor"
 
     # tuple of tensor list and list of tensor dict
     batch = ([torch.rand(2, 3) for _ in range(2)], [{"a": torch.rand(2, 3), "b": torch.rand(2, 3)} for _ in range(2)])
-    batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0"))
+    batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0"))
     assert batch[0][0].device.index == 0 and batch[0][0].type() == "torch.cuda.FloatTensor"
     assert batch[1][0]["a"].device.index == 0
 
@@ -292,7 +292,7 @@ def test_single_gpu_batch_parse():
     # namedtuple of tensor
     BatchType = namedtuple("BatchType", ["a", "b"])
     batch = [BatchType(a=torch.rand(2, 3), b=torch.rand(2, 3)) for _ in range(2)]
-    batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0"))
+    batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0"))
     assert batch[0].a.device.index == 0
     assert batch[0].a.type() == "torch.cuda.FloatTensor"
 
@@ -305,7 +305,7 @@ def to(self, *args, **kwargs):
             self.a = self.a.to(*args, **kwargs)
             return self
 
-    batch = trainer.accelerator.batch_to_device(CustomBatchType(), torch.device("cuda:0"))
+    batch = trainer.training_type_plugin.batch_to_device(CustomBatchType(), torch.device("cuda:0"))
     assert batch.a.type() == "torch.cuda.FloatTensor"
 
     # torchtext.data.Batch
@@ -326,7 +326,7 @@ def to(self, *args, **kwargs):
     label_field.build_vocab(dataset)
 
     batch = Batch(data=examples, dataset=dataset)
-    batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0"))
+    batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0"))
 
     assert batch.text.type() == "torch.cuda.LongTensor"
     assert batch.label.type() == "torch.cuda.LongTensor"
@@ -339,7 +339,7 @@ def test_non_blocking():
 
     batch = torch.zeros(2, 3)
     with patch.object(batch, "to", wraps=batch.to) as mocked:
-        batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0"))
+        batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0"))
         mocked.assert_called_with(torch.device("cuda", 0), non_blocking=True)
 
     class BatchObject:
@@ -348,5 +348,5 @@ def to(self, *args, **kwargs):
 
     batch = BatchObject()
     with patch.object(batch, "to", wraps=batch.to) as mocked:
-        batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0"))
+        batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0"))
         mocked.assert_called_with(torch.device("cuda", 0))
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index e8db816ed4edc..68dc9d2fefeb5 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -112,7 +112,10 @@ def on_train_batch_end(self, outputs, batch, batch_idx):
 
 
 @RunIf(min_gpus=1)
-@mock.patch("pytorch_lightning.accelerators.accelerator.Accelerator.lightning_module", new_callable=PropertyMock)
+@mock.patch(
+    "pytorch_lightning.plugins.training_type.training_type_plugin.TrainingTypePlugin.lightning_module",
+    new_callable=PropertyMock,
+)
 def test_apply_batch_transfer_handler(model_getter_mock):
     expected_device = torch.device("cuda", 0)
 
@@ -157,7 +160,7 @@ def transfer_batch_to_device(self, batch, device, dataloader_idx):
     # running .fit() would require us to implement custom data loaders, we mock the model reference instead
     model_getter_mock.return_value = model
 
-    batch_gpu = trainer.accelerator.batch_to_device(batch, expected_device)
+    batch_gpu = trainer.training_type_plugin.batch_to_device(batch, expected_device)
 
     assert model.on_before_batch_transfer_hook_rank == 0
     assert model.transfer_batch_to_device_hook_rank == 1
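Note: a minimal usage sketch of the relocated method, mirroring the single-GPU tests above (illustrative only, not part of the patch; assumes this revision of PyTorch Lightning and an available CUDA device):

    import torch

    from pytorch_lightning import Trainer

    trainer = Trainer(gpus=1)
    batch = torch.rand(2, 3)

    # the call now goes through the training type plugin; with no LightningModule
    # attached, the default implementation falls back to move_data_to_device
    batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0"))
    assert batch.device.type == "cuda"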