diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 9626390d63521a..ecec428ff9248a 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -461,6 +461,7 @@ def _set_parallel_devices_and_init_accelerator(self):
         )

         self._gpus = self._device_flag if not self._gpus else self._gpus
+        self._tpu_cores = self._device_flag if not self._tpu_cores else self._tpu_cores

     def _choose_and_init_cluster_environment(self):
         self.cluster_environment = LightningEnvironment()
@@ -484,7 +485,7 @@ def _is_slurm_managing_tasks(self):
         return num_slurm_tasks == total_requested_devices

     def _choose_strategy(self):
-        if self._accelerator_flag == "ipu_strategy":
+        if self._accelerator_flag == "ipu":
             self._strategy_flag = "ipu_strategy"
         elif self._accelerator_flag == "tpu":
             if self._parallel_devices and len(self._parallel_devices) > 1:
@@ -757,15 +758,15 @@ def devices(self):
         return 0

     @property
-    def tpu_cores(self) -> int:
+    def tpu_cores(self):
         if isinstance(self.accelerator, TPUAccelerator):
-            return self.devices
+            return self._tpu_cores
         return 0

     @property
     def tpu_id(self) -> Optional[int]:
         if isinstance(self.accelerator, TPUAccelerator):
-            return self.parallel_devices[0]
+            return self.tpu_cores[0]
         return None

     @property
diff --git a/tests/accelerators/test_ipu.py b/tests/accelerators/test_ipu.py
index 9fadb6a52cade1..9f6e4d931376ec 100644
--- a/tests/accelerators/test_ipu.py
+++ b/tests/accelerators/test_ipu.py
@@ -115,7 +115,7 @@ def test_accelerator_selected(tmpdir):
 @RunIf(ipu=True)
 def test_warning_if_ipus_not_used(tmpdir):
     with pytest.warns(UserWarning, match="IPU available but not used. Set the `ipus` flag in your trainer"):
-        Trainer(default_root_dir=tmpdir)
+        Trainer(default_root_dir=tmpdir, accelerator="cpu")


 @RunIf(ipu=True)
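
Note (not part of the patch): a minimal sketch of the behaviour the `tpu_cores`/`tpu_id` property changes above are meant to restore, assuming a TPU-enabled environment and the Trainer flags available on this branch; the specific `tpu_cores=[1]` argument and printed values are illustrative, not taken from the diff.

from pytorch_lightning import Trainer

# Request a single TPU core. With the patch, `trainer.tpu_cores` is backed by
# the dedicated `_tpu_cores` flag instead of the generic `devices` count, so
# legacy code that inspects it keeps working.
trainer = Trainer(accelerator="tpu", tpu_cores=[1])
print(trainer.tpu_cores)  # expected: [1]

# `tpu_id` now indexes into `tpu_cores`, so a single-core request exposes the
# requested core id; on non-TPU accelerators it remains None.
print(trainer.tpu_id)     # expected: 1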