Skip to content

Commit

Permalink
debug tpu
Browse files Browse the repository at this point in the history
  • Loading branch information
four4fish committed Jan 29, 2022
1 parent e7c3bbf commit a0f80a6
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,16 +485,15 @@ def _is_slurm_managing_tasks(self):
return num_slurm_tasks == total_requested_devices

def _choose_strategy(self):
if _HOROVOD_AVAILABLE and ("OMPI_COMM_WORLD_RANK" in os.environ or "HOROVOD_RANK" in os.environ):
self._strategy_flag = HorovodStrategy()

if self._accelerator_flag == "ipu":
self._strategy_flag = "ipu"
elif self._accelerator_flag == "tpu":
if self._parallel_devices and len(self._parallel_devices) > 1:
self._strategy_flag = "tpu_spawn"
else:
self._srategy_flag = SingleTPUStrategy(device=self._parallel_devices[0])
elif _HOROVOD_AVAILABLE and ("OMPI_COMM_WORLD_RANK" in os.environ or "HOROVOD_RANK" in os.environ):
self._strategy_flag = HorovodStrategy()
else:
if self._num_nodes_flag > 1:
self._strategy_flag = "ddp"
Expand Down Expand Up @@ -549,8 +548,10 @@ def _strategy_check_and_fallbacks(self):
def _init_strategy(self):
if isinstance(self._strategy_flag, str):
self.strategy = StrategyRegistry.get(self._strategy_flag)
else:
elif isinstance(self._strategy_flag, Strategy):
self.strategy = self._strategy_flag
else:
raise RuntimeError(f"{self.strategy} is not valid type: {self.strategy}")

def handle_horovod(self):
if self._num_nodes_flag > 1:
Expand Down

0 comments on commit a0f80a6

Please sign in to comment.