1/n Move Accelerator into strategy - move batch_to_device to strategy (#10649)

* 1/n Integrate Device Specific Accelerator Logic with strategy - move batch_to_device to strategy

* add changelog

* add model is not none check

* Apply suggestions from code review

Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>

* Update CHANGELOG.md

* Update test_datamodules.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test_hooks.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update dp.py

Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
4 people authored Nov 29, 2021
1 parent 753cc4d commit 8bf7f9c
Showing 11 changed files with 59 additions and 44 deletions.
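
The gist of the diff below: every call site that used to go through `trainer.accelerator.batch_to_device` now goes through `trainer.training_type_plugin.batch_to_device`, and the method itself moves from `Accelerator` to `TrainingTypePlugin` (with a `DataParallelPlugin` override). A minimal sketch of the new call path, assuming pytorch-lightning at roughly this commit's version and a single CUDA device; the toy batch is illustrative and not part of the diff:

```python
import torch
from pytorch_lightning import Trainer

trainer = Trainer(gpus=1)
batch = {"x": torch.rand(2, 3), "y": torch.zeros(2)}

# before this commit the accelerator owned the call:
#     trainer.accelerator.batch_to_device(batch, torch.device("cuda:0"))
moved = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0"))
assert moved["x"].is_cuda and moved["y"].is_cuda
```
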
5 changes: 1 addition & 4 deletions CHANGELOG.md
@@ -68,10 +68,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Raised an error if the `batch_size` cannot be inferred from the current batch if it contained a string or was a custom batch object ([#10541](https://github.com/PyTorchLightning/pytorch-lightning/pull/10541))


-


-
- Moved `batch_to_device` method from `Accelerator` to `TrainingTypePlugin` ([#10649](https://github.com/PyTorchLightning/pytorch-lightning/pull/10649))


-
20 changes: 1 addition & 19 deletions pytorch_lightning/accelerators/accelerator.py
@@ -23,7 +23,7 @@

import pytorch_lightning as pl
from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
-from pytorch_lightning.plugins.training_type import DataParallelPlugin, TrainingTypePlugin
+from pytorch_lightning.plugins.training_type import TrainingTypePlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
@@ -145,24 +145,6 @@ def teardown(self) -> None:
"""
self.training_type_plugin.teardown()

def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
"""Moves the batch to the correct device. The returned batch is of the same type as the input batch, just
having all tensors on the correct device.
Args:
batch: The batch of samples to move to the correct device
device: The target device
dataloader_idx: The index of the dataloader to which the batch belongs.
"""
model = self.lightning_module
device = device or self.root_device

if model is not None and not isinstance(self.training_type_plugin, DataParallelPlugin):
# no need to transfer batch to device in DP mode
return model._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx)

return move_data_to_device(batch, device)

def training_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
"""The actual training step.
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -110,7 +110,7 @@ def advance(

if not self.trainer._data_connector.evaluation_data_fetcher.store_on_device:
with self.trainer.profiler.profile("evaluation_batch_to_device"):
-batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx)
+batch = self.trainer.training_type_plugin.batch_to_device(batch, dataloader_idx=dataloader_idx)

self.batch_progress.increment_ready()

2 changes: 1 addition & 1 deletion pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -92,7 +92,7 @@ def advance(
raise StopIteration

with self.trainer.profiler.profile("predict_batch_to_device"):
-batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx)
+batch = self.trainer.training_type_plugin.batch_to_device(batch, dataloader_idx=dataloader_idx)

self.batch_progress.increment_ready()

2 changes: 1 addition & 1 deletion pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -156,7 +156,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:

if not self.trainer._data_connector.train_data_fetcher.store_on_device:
with self.trainer.profiler.profile("training_batch_to_device"):
-batch = self.trainer.accelerator.batch_to_device(batch)
+batch = self.trainer.training_type_plugin.batch_to_device(batch)

self.batch_progress.increment_ready()

16 changes: 14 additions & 2 deletions pytorch_lightning/plugins/training_type/dp.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Optional
+from typing import Any, List, Optional

import torch
from torch.nn import DataParallel, Module
@@ -20,7 +20,7 @@
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
-from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import _METRIC_COLLECTION
@@ -66,6 +66,18 @@ def setup(self) -> None:
self.model_to_device()
self._model = self._setup_model(LightningParallelModule(self._model))

def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
"""Moves the batch to the correct device.
The input and the output have the same type.
Args:
batch: The batch of samples to move to the correct device
device: The target device
dataloader_idx: The index of the dataloader to which the batch belongs.
"""
return move_data_to_device(batch, device=device or self.root_device)

def _setup_model(self, model: Module) -> DataParallel:
"""Wraps the given model into a :class:`~torch.nn.parallel.DataParallel` module."""
return DataParallel(module=model, device_ids=self.parallel_devices)
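
The override above mirrors the special case that used to live in `Accelerator.batch_to_device` ("no need to transfer batch to device in DP mode"): under DP the plugin only moves the collection to the root device and does not run the LightningModule's batch-transfer hooks, since `DataParallel` scatters the batch across replicas afterwards. A hedged usage sketch, assuming at least two GPUs and the `strategy="dp"` Trainer flag available at this version:

```python
import torch
from pytorch_lightning import Trainer

trainer = Trainer(gpus=2, strategy="dp")
batch = [torch.rand(4, 3), {"y": torch.rand(4)}]

# the DataParallelPlugin override just calls move_data_to_device(batch, root_device),
# so everything lands on cuda:0 first; scattering to replicas happens later
moved = trainer.training_type_plugin.batch_to_device(batch)
assert moved[0].device == torch.device("cuda", 0)
```
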
18 changes: 18 additions & 0 deletions pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -26,6 +26,7 @@
from pytorch_lightning.plugins import TorchCheckpointIO
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT

@@ -90,6 +91,23 @@ def _setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
# TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324
return optimizer

def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
"""Moves the batch to the correct device.
The returned batch is of the same type as the input batch, just
having all tensors on the correct device.
Args:
batch: The batch of samples to move to the correct device
device: The target device
dataloader_idx: The index of the dataloader to which the batch belongs.
"""
model = self.lightning_module
device = device or self.root_device
if model is not None:
return model._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx)
return move_data_to_device(batch, device)

@property
@abstractmethod
def on_gpu(self) -> bool:
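
When a LightningModule is attached, the default path above delegates to `_apply_batch_transfer_handler`, which runs the module's batch-transfer hooks in a fixed order (the same order the updated tests below assert). A small illustrative sketch of those hooks, not taken from the diff:

```python
import torch
from pytorch_lightning import LightningModule


class TransferDemo(LightningModule):
    def on_before_batch_transfer(self, batch, dataloader_idx):
        # 1) runs first, while the batch is still on the source device
        batch["x"] = batch["x"].float()
        return batch

    def transfer_batch_to_device(self, batch, device, dataloader_idx):
        # 2) decides how the batch reaches the target device
        return {k: v.to(device) for k, v in batch.items()}

    def on_after_batch_transfer(self, batch, dataloader_idx):
        # 3) runs last, once the batch is on the target device
        batch["x"] = batch["x"] * 2
        return batch
```

With such a module attached (i.e. `lightning_module` is not None), `batch_to_device` drives these three hooks; otherwise it falls back to the plain `move_data_to_device` call shown above.
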
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/connectors/data_connector.py
@@ -118,7 +118,7 @@ def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int = 0)
data_fetcher.setup(
dataloader,
stage=stage,
-batch_to_device=partial(self.trainer.accelerator.batch_to_device, dataloader_idx=dataloader_idx),
+batch_to_device=partial(self.trainer.training_type_plugin.batch_to_device, dataloader_idx=dataloader_idx),
profiler=self.trainer.profiler,
)
setattr(self, f"{stage}_data_fetcher", data_fetcher)
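
The data fetcher only ever passes the batch itself, so `dataloader_idx` is bound up front with `functools.partial`, as the hunk above shows. A standalone sketch of that pattern with a toy function (not the real plugin method):

```python
from functools import partial

import torch


def batch_to_device(batch, device=None, dataloader_idx=0):
    # toy stand-in with the same signature as the plugin method
    print(f"moving batch for dataloader {dataloader_idx}")
    return batch.to(device or torch.device("cpu"))


move_fn = partial(batch_to_device, dataloader_idx=1)
move_fn(torch.rand(2, 3))  # the fetcher later calls it with just the batch
```
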
7 changes: 5 additions & 2 deletions tests/core/test_datamodules.py
@@ -253,7 +253,10 @@ def test_full_loop(tmpdir):


@RunIf(min_gpus=1)
-@mock.patch("pytorch_lightning.accelerators.accelerator.Accelerator.lightning_module", new_callable=PropertyMock)
+@mock.patch(
+    "pytorch_lightning.plugins.training_type.training_type_plugin.TrainingTypePlugin.lightning_module",
+    new_callable=PropertyMock,
+)
def test_dm_apply_batch_transfer_handler(get_module_mock):
expected_device = torch.device("cuda", 0)

@@ -306,7 +309,7 @@ def transfer_batch_to_device(self, batch, device, dataloader_idx):
model.transfer_batch_to_device = dm.transfer_batch_to_device
model.on_after_batch_transfer = dm.on_after_batch_transfer

-batch_gpu = trainer.accelerator.batch_to_device(batch, expected_device)
+batch_gpu = trainer.training_type_plugin.batch_to_device(batch, expected_device)

assert dm.on_before_batch_transfer_hook_rank == 0
assert dm.transfer_batch_to_device_hook_rank == 1
22 changes: 11 additions & 11 deletions tests/models/test_gpu.py
@@ -252,35 +252,35 @@ def test_single_gpu_batch_parse():
# non-transferrable types
primitive_objects = [None, {}, [], 1.0, "x", [None, 2], {"x": (1, 2), "y": None}]
for batch in primitive_objects:
-data = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0"))
+data = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0"))
assert data == batch

# batch is just a tensor
batch = torch.rand(2, 3)
-batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0"))
+batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0"))
assert batch.device.index == 0 and batch.type() == "torch.cuda.FloatTensor"

# tensor list
batch = [torch.rand(2, 3), torch.rand(2, 3)]
-batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0"))
+batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0"))
assert batch[0].device.index == 0 and batch[0].type() == "torch.cuda.FloatTensor"
assert batch[1].device.index == 0 and batch[1].type() == "torch.cuda.FloatTensor"

# tensor list of lists
batch = [[torch.rand(2, 3), torch.rand(2, 3)]]
-batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0"))
+batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0"))
assert batch[0][0].device.index == 0 and batch[0][0].type() == "torch.cuda.FloatTensor"
assert batch[0][1].device.index == 0 and batch[0][1].type() == "torch.cuda.FloatTensor"

# tensor dict
batch = [{"a": torch.rand(2, 3), "b": torch.rand(2, 3)}]
-batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0"))
+batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0"))
assert batch[0]["a"].device.index == 0 and batch[0]["a"].type() == "torch.cuda.FloatTensor"
assert batch[0]["b"].device.index == 0 and batch[0]["b"].type() == "torch.cuda.FloatTensor"

# tuple of tensor list and list of tensor dict
batch = ([torch.rand(2, 3) for _ in range(2)], [{"a": torch.rand(2, 3), "b": torch.rand(2, 3)} for _ in range(2)])
-batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0"))
+batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0"))
assert batch[0][0].device.index == 0 and batch[0][0].type() == "torch.cuda.FloatTensor"

assert batch[1][0]["a"].device.index == 0
@@ -292,7 +292,7 @@ def test_single_gpu_batch_parse():
# namedtuple of tensor
BatchType = namedtuple("BatchType", ["a", "b"])
batch = [BatchType(a=torch.rand(2, 3), b=torch.rand(2, 3)) for _ in range(2)]
-batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0"))
+batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0"))
assert batch[0].a.device.index == 0
assert batch[0].a.type() == "torch.cuda.FloatTensor"

@@ -305,7 +305,7 @@ def to(self, *args, **kwargs):
self.a = self.a.to(*args, **kwargs)
return self

-batch = trainer.accelerator.batch_to_device(CustomBatchType(), torch.device("cuda:0"))
+batch = trainer.training_type_plugin.batch_to_device(CustomBatchType(), torch.device("cuda:0"))
assert batch.a.type() == "torch.cuda.FloatTensor"

# torchtext.data.Batch
@@ -326,7 +326,7 @@ def to(self, *args, **kwargs):
label_field.build_vocab(dataset)

batch = Batch(data=examples, dataset=dataset)
-batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0"))
+batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0"))

assert batch.text.type() == "torch.cuda.LongTensor"
assert batch.label.type() == "torch.cuda.LongTensor"
@@ -339,7 +339,7 @@ def test_non_blocking():

batch = torch.zeros(2, 3)
with patch.object(batch, "to", wraps=batch.to) as mocked:
-batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0"))
+batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0"))
mocked.assert_called_with(torch.device("cuda", 0), non_blocking=True)

class BatchObject:
@@ -348,5 +348,5 @@ def to(self, *args, **kwargs):

batch = BatchObject()
with patch.object(batch, "to", wraps=batch.to) as mocked:
-batch = trainer.accelerator.batch_to_device(batch, torch.device("cuda:0"))
+batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0"))
mocked.assert_called_with(torch.device("cuda", 0))
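
The `test_non_blocking` changes above keep the same device-transfer semantics, only routed through the plugin: plain tensors are moved with `non_blocking=True` (as the mock assertion shows), while objects that merely define `.to()` are called without it. A small sketch, assuming a CUDA device is available:

```python
import torch
from pytorch_lightning.utilities.apply_func import move_data_to_device

# pinned host memory is what makes a non_blocking copy actually asynchronous
x = torch.zeros(2, 3).pin_memory()
y = move_data_to_device(x, torch.device("cuda:0"))  # tensor path uses .to(device, non_blocking=True)
assert y.is_cuda
```
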
7 changes: 5 additions & 2 deletions tests/models/test_hooks.py
@@ -112,7 +112,10 @@ def on_train_batch_end(self, outputs, batch, batch_idx):


@RunIf(min_gpus=1)
-@mock.patch("pytorch_lightning.accelerators.accelerator.Accelerator.lightning_module", new_callable=PropertyMock)
+@mock.patch(
+    "pytorch_lightning.plugins.training_type.training_type_plugin.TrainingTypePlugin.lightning_module",
+    new_callable=PropertyMock,
+)
def test_apply_batch_transfer_handler(model_getter_mock):
expected_device = torch.device("cuda", 0)

@@ -157,7 +160,7 @@ def transfer_batch_to_device(self, batch, device, dataloader_idx):
# running .fit() would require us to implement custom data loaders, we mock the model reference instead

model_getter_mock.return_value = model
-batch_gpu = trainer.accelerator.batch_to_device(batch, expected_device)
+batch_gpu = trainer.training_type_plugin.batch_to_device(batch, expected_device)

assert model.on_before_batch_transfer_hook_rank == 0
assert model.transfer_batch_to_device_hook_rank == 1
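
Because `batch_to_device` now reads `self.lightning_module` on the training-type plugin, the tests patch the property at its new home rather than on the accelerator. A minimal sketch of that mocking pattern; the test body is illustrative only:

```python
from unittest import mock
from unittest.mock import PropertyMock


@mock.patch(
    "pytorch_lightning.plugins.training_type.training_type_plugin.TrainingTypePlugin.lightning_module",
    new_callable=PropertyMock,
)
def test_uses_patched_module(model_getter_mock):
    # the plugin (and hence batch_to_device) now sees whatever model we return here
    model_getter_mock.return_value = None
```
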
