diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2b845421d7fd4..e52d811990a48 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -152,6 +152,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support to explicitly specify the process group backend for parallel strategies ([#11745](https://github.com/PyTorchLightning/pytorch-lightning/pull/11745))
 
 
+- Added `device_ids` and `num_devices` properties to `Trainer` ([#12151](https://github.com/PyTorchLightning/pytorch-lightning/pull/12151))
+
+
 ### Changed
 
 - Drop PyTorch 1.7 support ([#12191](https://github.com/PyTorchLightning/pytorch-lightning/pull/12191))
@@ -518,6 +521,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `ParallelPlugin.torch_distributed_backend` in favor of `DDPStrategy.process_group_backend` property ([#11745](https://github.com/PyTorchLightning/pytorch-lightning/pull/11745))
 
 
+- Deprecated `Trainer.devices` in favor of `Trainer.num_devices` and `Trainer.device_ids` ([#12151](https://github.com/PyTorchLightning/pytorch-lightning/pull/12151))
+
+
 ### Removed
 
 - Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 999f55b4f10b7..e341d553ad5b4 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -2010,6 +2010,23 @@ def should_rank_save_checkpoint(self) -> bool:
     def num_nodes(self) -> int:
         return getattr(self.strategy, "num_nodes", 1)
 
+    @property
+    def device_ids(self) -> List[int]:
+        """List of device indexes per node."""
+        devices = getattr(self.strategy, "parallel_devices", [self.strategy.root_device])
+        device_ids = []
+        for idx, device in enumerate(devices):
+            if isinstance(device, torch.device):
+                device_ids.append(device.index or idx)
+            elif isinstance(device, int):
+                device_ids.append(device)
+        return device_ids
+
+    @property
+    def num_devices(self) -> int:
+        """Number of devices the trainer uses per node."""
+        return len(self.device_ids)
+
     @property
     def num_processes(self) -> int:
         return self._accelerator_connector.num_processes
@@ -2031,8 +2048,12 @@ def num_gpus(self) -> int:
         return self._accelerator_connector.num_gpus
 
     @property
-    def devices(self) -> Optional[Union[List[int], str, int]]:
-        return self._accelerator_connector.devices
+    def devices(self) -> int:
+        rank_zero_deprecation(
+            "`Trainer.devices` was deprecated in v1.6 and will be removed in v1.8."
+            " Please use `Trainer.num_devices` or `Trainer.device_ids` to get device information instead."
+        )
+        return self.num_devices
 
     @property
     def data_parallel_device_ids(self) -> Optional[List[int]]:
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index 8e79ce1caa6b8..5c2502fa6119d 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -579,14 +579,14 @@ def test_validate_accelerator_and_devices():
 
 def test_set_devices_if_none_cpu():
     trainer = Trainer(accelerator="cpu", num_processes=3)
-    assert trainer.devices == 3
+    assert trainer.num_devices == 3
 
 
 @RunIf(min_gpus=2)
 def test_set_devices_if_none_gpu():
     trainer = Trainer(accelerator="gpu", gpus=2)
-    assert trainer.devices == 2
+    assert trainer.num_devices == 2
 
 
 def test_devices_with_cpu_only_supports_integer():
@@ -594,7 +594,7 @@ def test_devices_with_cpu_only_supports_integer():
     with pytest.warns(UserWarning, match="The flag `devices` must be an int"):
         trainer = Trainer(accelerator="cpu", devices="1,3")
     assert isinstance(trainer.accelerator, CPUAccelerator)
-    assert trainer.devices == 1
+    assert trainer.num_devices == 1
 
 
 @pytest.mark.parametrize("training_type", ["ddp2", "dp"])
@@ -941,7 +941,7 @@ def test_unsupported_ipu_choice(mock_ipu_acc_avail, monkeypatch):
 @mock.patch("pytorch_lightning.utilities.imports._IPU_AVAILABLE", return_value=False)
 def test_devices_auto_choice_cpu(is_ipu_available_mock, is_tpu_available_mock, is_gpu_available_mock):
     trainer = Trainer(accelerator="auto", devices="auto")
-    assert trainer.devices == 1
+    assert trainer.num_devices == 1
     assert trainer.num_processes == 1
 
 
@@ -949,7 +949,7 @@ def test_devices_auto_choice_cpu(is_ipu_available_mock, is_tpu_available_mock, i
 @mock.patch("torch.cuda.device_count", return_value=2)
 def test_devices_auto_choice_gpu(is_gpu_available_mock, device_count_mock):
     trainer = Trainer(accelerator="auto", devices="auto")
-    assert trainer.devices == 2
+    assert trainer.num_devices == 2
     assert trainer.gpus == 2
diff --git a/tests/accelerators/test_ipu.py b/tests/accelerators/test_ipu.py
index 00fe86995f4c3..5a09d654bf437 100644
--- a/tests/accelerators/test_ipu.py
+++ b/tests/accelerators/test_ipu.py
@@ -400,7 +400,7 @@ def test_manual_poptorch_opts(tmpdir):
     dataloader = trainer.train_dataloader.loaders
     assert isinstance(dataloader, poptorch.DataLoader)
     assert dataloader.options == training_opts
-    assert trainer.devices > 1  # testing this only makes sense in a distributed setting
+    assert trainer.num_devices > 1  # testing this only makes sense in a distributed setting
     assert not isinstance(dataloader.sampler, DistributedSampler)
 
 
@@ -588,7 +588,7 @@ def test_accelerator_ipu_with_ipus_priority():
 
 def test_set_devices_if_none_ipu():
     trainer = Trainer(accelerator="ipu", ipus=8)
-    assert trainer.devices == 8
+    assert trainer.num_devices == 8
 
 
 @RunIf(ipu=True)
@@ -631,5 +631,5 @@ def test_poptorch_models_at_different_stages(tmpdir):
 @RunIf(ipu=True)
 def test_devices_auto_choice_ipu():
     trainer = Trainer(accelerator="auto", devices="auto")
-    assert trainer.devices == 4
+    assert trainer.num_devices == 4
     assert trainer.ipus == 4
diff --git a/tests/accelerators/test_tpu.py b/tests/accelerators/test_tpu.py
index 1e74cde1f70c6..5fc2aba8cbe57 100644
--- a/tests/accelerators/test_tpu.py
+++ b/tests/accelerators/test_tpu.py
@@ -94,6 +94,7 @@ def test_accelerator_cpu_with_tpu_cores_flag():
 
 
 @RunIf(tpu=True)
+@pl_multi_process_test
 @pytest.mark.parametrize(["accelerator", "devices"], [("auto", 8), ("auto", "auto"), ("tpu", None)])
 def test_accelerator_tpu(accelerator, devices):
     assert TPUAccelerator.is_available()
@@ -101,7 +102,7 @@ def test_accelerator_tpu(accelerator, devices):
     trainer = Trainer(accelerator=accelerator, devices=devices)
     assert isinstance(trainer.accelerator, TPUAccelerator)
     assert isinstance(trainer.strategy, TPUSpawnStrategy)
-    assert trainer.devices == 8
+    assert trainer.num_devices == 8
     assert trainer.tpu_cores == 8
 
 
@@ -117,10 +118,10 @@ def test_accelerator_tpu_with_tpu_cores_priority():
 
 
 @RunIf(tpu=True)
+@pl_multi_process_test
 def test_set_devices_if_none_tpu():
-
     trainer = Trainer(accelerator="tpu", tpu_cores=8)
-    assert trainer.devices == 8
+    assert trainer.num_devices == 8
 
 
 @RunIf(tpu=True)
@@ -310,3 +311,21 @@ def test_mp_device_dataloader_attribute(_):
 def test_warning_if_tpus_not_used():
     with pytest.warns(UserWarning, match="TPU available but not used. Set `accelerator` and `devices`"):
         Trainer()
+
+
+@pytest.mark.skip(reason="TODO(@kaushikb11): Optimize TPU tests to avoid timeouts")
+@RunIf(tpu=True)
+@pytest.mark.parametrize(
+    ["devices", "expected_device_ids"],
+    [
+        (1, [0]),
+        (8, list(range(8))),
+        ("8", list(range(8))),
+        ([2], [2]),
+        ("2,", [2]),
+    ],
+)
+def test_trainer_config_device_ids(devices, expected_device_ids):
+    trainer = Trainer(accelerator="tpu", devices=devices)
+    assert trainer.device_ids == expected_device_ids
+    assert trainer.num_devices == len(expected_device_ids)
diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py
index 7acc43851bbb2..101f711619078 100644
--- a/tests/deprecated_api/test_remove_1-8.py
+++ b/tests/deprecated_api/test_remove_1-8.py
@@ -878,3 +878,12 @@ def all_gather(self, tensor):
         match="ParallelStrategy.torch_distributed_backend was deprecated" " in v1.6 and will be removed in v1.8."
     ):
         strategy.torch_distributed_backend
+
+
+def test_trainer_config_device_ids():
+    trainer = Trainer(devices=2)
+    with pytest.deprecated_call(
+        match="`Trainer.devices` was deprecated in v1.6 and will be removed in v1.8."
+        " Please use `Trainer.num_devices` or `Trainer.device_ids` to get device information instead."
+    ):
+        assert trainer.devices == 2
diff --git a/tests/trainer/flags/test_env_vars.py b/tests/trainer/flags/test_env_vars.py
index 58e2b8e9cb439..e7c9a13a0cd3c 100644
--- a/tests/trainer/flags/test_env_vars.py
+++ b/tests/trainer/flags/test_env_vars.py
@@ -51,6 +51,6 @@ def test_passing_env_variables_defaults():
 def test_passing_env_variables_devices(cuda_available_mock, device_count_mock):
     """Testing overwriting trainer arguments."""
     trainer = Trainer()
-    assert trainer.devices == 2
+    assert trainer.num_devices == 2
     trainer = Trainer(accelerator="gpu", devices=1)
-    assert trainer.devices == 1
+    assert trainer.num_devices == 1
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 6f4d7300220e5..7e44d85ae7ea0 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -30,6 +30,7 @@
 from torch.optim import SGD
 from torch.utils.data import DataLoader, IterableDataset
 
+import pytorch_lightning
 import tests.helpers.utils as tutils
 from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
 from pytorch_lightning.accelerators import CPUAccelerator, GPUAccelerator
@@ -2117,3 +2118,35 @@ def test_dataloaders_are_not_loaded_if_disabled_through_limit_batches(running_st
         else getattr(trainer, f"{dl_prefix}_dataloaders")
     )
     assert dl is None
+
+
+@pytest.mark.parametrize(
+    ["trainer_kwargs", "expected_device_ids"],
+    [
+        ({}, [0]),
+        ({"devices": 1}, [0]),
+        ({"devices": 1}, [0]),
+        ({"devices": "1"}, [0]),
+        ({"devices": 2}, [0, 1]),
+        ({"accelerator": "gpu", "devices": 1}, [0]),
+        ({"accelerator": "gpu", "devices": 2}, [0, 1]),
+        ({"accelerator": "gpu", "devices": "2"}, [0, 1]),
+        ({"accelerator": "gpu", "devices": [2]}, [2]),
+        ({"accelerator": "gpu", "devices": "2,"}, [2]),
+        ({"accelerator": "gpu", "devices": [0, 2]}, [0, 2]),
+        ({"accelerator": "gpu", "devices": "0, 2"}, [0, 2]),
+        ({"accelerator": "ipu", "devices": 1}, [0]),
+        ({"accelerator": "ipu", "devices": 2}, [0, 1]),
+    ],
+)
+def test_trainer_config_device_ids(monkeypatch, trainer_kwargs, expected_device_ids):
+    if trainer_kwargs.get("accelerator") == "gpu":
+        monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
+        monkeypatch.setattr(torch.cuda, "device_count", lambda: 4)
+    elif trainer_kwargs.get("accelerator") == "ipu":
+        monkeypatch.setattr(pytorch_lightning.accelerators.ipu.IPUAccelerator, "is_available", lambda _: True)
+        monkeypatch.setattr(pytorch_lightning.strategies.ipu, "_IPU_AVAILABLE", lambda: True)
+
+    trainer = Trainer(**trainer_kwargs)
+    assert trainer.device_ids == expected_device_ids
+    assert trainer.num_devices == len(expected_device_ids)
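
A brief usage sketch (not part of the patch) of the accessors this diff introduces; it assumes a CPU-only machine running a Lightning version in the v1.6 deprecation window, so the printed values are illustrative:

    from pytorch_lightning import Trainer

    # `num_devices` and `device_ids` replace the deprecated `Trainer.devices` property.
    trainer = Trainer(accelerator="cpu", devices=2)
    print(trainer.num_devices)  # 2 -> number of devices used per node
    print(trainer.device_ids)   # [0, 1] -> per-node device indexes
    trainer.devices             # still returns 2 here, but now emits a deprecation warning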