Skip to content

Commit

Permalink
Fix TPU pod handling
Browse files Browse the repository at this point in the history
  • Loading branch information
iojw committed Dec 1, 2022
1 parent b025c01 commit 6782ab5
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 3 additions & 1 deletion sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from sky.utils import command_runner
from sky.utils import subprocess_utils
from sky.utils import timeline
from sky.utils import tpu_utils
from sky.utils import ux_utils
from sky.utils import validator
from sky.usage import usage_lib
Expand Down Expand Up @@ -1154,8 +1155,9 @@ def get_node_ips(cluster_yaml: str,
if use_tpu_vm:
ips = _get_tpu_vm_pod_ips(ray_config, get_internal_ips)
assert expected_num_nodes == 1, 'TPU VM only supports single node for now.'
if len(ips) != expected_num_nodes:
if len(ips) != tpu_utils.get_num_tpu_devices(handle.launched_resources):
raise exceptions.FetchIPError(exceptions.FetchIPError.Reason.HEAD)
return ips

if get_internal_ips:
with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
Expand Down
2 changes: 1 addition & 1 deletion sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,7 +1155,7 @@ def _tpu_pod_setup(self, cluster_yaml: str,
run setup or launch ray cluster on TPU VM Pod nodes.
"""
ssh_credentials = backend_utils.ssh_credential_from_yaml(cluster_yaml)
all_ips = cluster_handle.external_ips()
all_ips = cluster_handle.external_ips(use_cached_ips=False)
num_tpu_devices = tpu_utils.get_num_tpu_devices(
cluster_handle.launched_resources)
if len(all_ips) != num_tpu_devices:
Expand Down

0 comments on commit 6782ab5

Please sign in to comment.