Fix ResourceHandle semantics (#1481)
* Better handling of old clusters

* Handle IP functions should always return a list

* Fix head IP call

* Fix TPU pod handling

* Move assertion to start

* Update handling of IPs

* Always run tpu_vm_pod test

* Update types

* Fix linting

* Remove smoke test comment

* Update TPU pod smoke test
iojw authored Dec 1, 2022
1 parent 5506e1d commit 1955bee
Showing 3 changed files with 28 additions and 18 deletions.
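
The common thread across all three files: `ResourceHandle.internal_ips()` and `external_ips()` now return `Optional[List[str]]`, and every call site must handle the uncached (`None`) case explicitly. A minimal sketch of that contract, using a simplified stand-in rather than the real `ResourceHandle`:

    from typing import List, Optional, Tuple

    class HandleSketch:
        """Simplified stand-in for ResourceHandle (illustration only)."""

        def __init__(self) -> None:
            # (internal_ip, external_ip) pairs, or None if IPs were never fetched.
            self.stable_internal_external_ips: Optional[List[Tuple[str, str]]] = None

        def external_ips(self) -> Optional[List[str]]:
            # The new contract: return None when nothing is cached, instead
            # of raising or returning a half-valid value.
            if self.stable_internal_external_ips is not None:
                return [ext for _, ext in self.stable_internal_external_ips]
            return None

    handle = HandleSketch()
    assert handle.external_ips() is None          # nothing cached yet
    handle.stable_internal_external_ips = [('10.0.0.2', '34.1.2.3')]
    assert handle.external_ips() == ['34.1.2.3']  # cached list returned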
14 changes: 9 additions & 5 deletions sky/backends/backend_utils.py
@@ -48,6 +48,7 @@
 from sky.utils import command_runner
 from sky.utils import subprocess_utils
 from sky.utils import timeline
+from sky.utils import tpu_utils
 from sky.utils import ux_utils
 from sky.utils import validator
 from sky.usage import usage_lib
@@ -1152,10 +1153,11 @@ def get_node_ips(cluster_yaml: str,
     ray_config = common_utils.read_yaml(cluster_yaml)
     use_tpu_vm = ray_config['provider'].get('_has_tpus', False)
     if use_tpu_vm:
-        ips = _get_tpu_vm_pod_ips(ray_config, get_internal_ips)
-        assert expected_num_nodes == 1, 'TPU VM only supports single node for now.'
-        if len(ips) != expected_num_nodes:
+        ips = _get_tpu_vm_pod_ips(ray_config, get_internal_ips)
+        if len(ips) != tpu_utils.get_num_tpu_devices(handle.launched_resources):
             raise exceptions.FetchIPError(exceptions.FetchIPError.Reason.HEAD)
+        return ips

     if get_internal_ips:
         with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
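
For intuition on why the old single-node assertion had to go: a TPU Pod slice spans multiple hosts (a v2-32 slice is 32 cores at 8 cores per host, i.e. 4 hosts), so the fetched IP list is validated against the device count instead. A hypothetical standalone version of the check (`check_tpu_pod_ips` is not in the codebase):

    from typing import List

    def check_tpu_pod_ips(ips: List[str], num_tpu_devices: int) -> List[str]:
        # One IP per TPU host is expected; anything else means the pod is
        # not fully reachable yet.
        if len(ips) != num_tpu_devices:
            raise RuntimeError(
                f'Fetched {len(ips)} IPs, expected {num_tpu_devices}.')
        return ips

    # A v2-32 slice: 32 cores / 8 cores per host = 4 hosts.
    check_tpu_pod_ips(['10.0.0.1', '10.0.0.2', '10.0.0.3', '10.0.0.4'], 4)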
@@ -1626,7 +1628,7 @@ def _update_cluster_status_no_lock(
         # in ray's get IPs vs. ray runtime failing.
         external_ips = handle.external_ips(use_cached_ips=False)
         # This happens to a stopped TPU VM as we use gcloud to query the IP.
-        if len(external_ips) == 0:
+        if external_ips is None or len(external_ips) == 0:
             raise exceptions.FetchIPError(
                 reason=exceptions.FetchIPError.Reason.HEAD)
         if handle.launched_nodes == 1:
@@ -1694,8 +1696,10 @@ def _update_cluster_status_no_lock(
         # that the cluster is partially preempted.
         # TODO(zhwu): the definition of INIT should be audited/changed.
         # Adding a new status UNHEALTHY for abnormal status can be a choice.
-        global_user_state.set_cluster_status(
-            cluster_name, global_user_state.ClusterStatus.INIT)
+        global_user_state.add_or_update_cluster(cluster_name,
+                                                handle,
+                                                ready=False,
+                                                is_launch=False)
         return global_user_state.get_cluster_from_name(cluster_name)
     # Now is_abnormal is False: either node_statuses is empty or all nodes are STOPPED.
     backend = backends.CloudVmRayBackend()
28 changes: 18 additions & 10 deletions sky/backends/cloud_vm_ray_backend.py
@@ -1155,12 +1155,12 @@ def _tpu_pod_setup(self, cluster_yaml: str,
         run setup or launch ray cluster on TPU VM Pod nodes.
         """
         ssh_credentials = backend_utils.ssh_credential_from_yaml(cluster_yaml)
-        all_ips = cluster_handle.external_ips()
+        all_ips = cluster_handle.external_ips(use_cached_ips=False)
         num_tpu_devices = tpu_utils.get_num_tpu_devices(
             cluster_handle.launched_resources)
-        if len(all_ips) != num_tpu_devices:
+        if all_ips is None or len(all_ips) != num_tpu_devices:
             raise RuntimeError(
-                f'Number of nodes IPs: {len(all_ips)} does not'
+                f'Nodes IPs: {all_ips} does not '
                 f'match number of TPU devices: {num_tpu_devices}.')

         # Get the private IP of head node for connecting Ray cluster.
@@ -1709,7 +1709,7 @@ def _update_stable_cluster_ips(self,

     def internal_ips(self,
                      max_attempts: int = 1,
-                     use_cached_ips: bool = True):
+                     use_cached_ips: bool = True) -> Optional[List[str]]:
         if not use_cached_ips:
             self._update_stable_cluster_ips(max_attempts=max_attempts)
         if self.stable_internal_external_ips is not None:
@@ -1718,7 +1718,7 @@ def internal_ips(self,

     def external_ips(self,
                      max_attempts: int = 1,
-                     use_cached_ips: bool = True):
+                     use_cached_ips: bool = True) -> Optional[List[str]]:
         if not use_cached_ips:
             self._update_stable_cluster_ips(max_attempts=max_attempts)
         if self.stable_internal_external_ips is not None:
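
With both accessors typed as `Optional[List[str]]`, the rest of the diff applies one of two caller patterns, sketched below with illustrative helper names: code paths that run after provisioning assert that the cache is populated, while the status-refresh path treats `None` like an empty result:

    from typing import List, Optional

    def head_ip_after_provisioning(ips: Optional[List[str]]) -> str:
        # Post-provisioning call sites: IPs must already be cached; fail fast.
        assert ips is not None, 'external_ips is not cached in handle'
        return ips[0]

    def looks_unreachable(ips: Optional[List[str]]) -> bool:
        # Status-refresh call site: "never fetched" counts the same as
        # "no IPs", e.g. a stopped TPU VM queried via gcloud.
        return ips is None or len(ips) == 0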
@@ -2046,6 +2046,7 @@ def _sync_workdir(self, handle: ResourceHandle, workdir: Path) -> None:
         fore = colorama.Fore
         style = colorama.Style
         ip_list = handle.external_ips()
+        assert ip_list is not None, 'external_ips is not cached in handle'
         full_workdir = os.path.abspath(os.path.expanduser(workdir))

         # These asserts have been validated at Task construction time.
@@ -2125,7 +2126,8 @@ def _setup(self, handle: ResourceHandle, task: task_lib.Task,
             setup_sh_path = f.name
         setup_file = os.path.basename(setup_sh_path)
         # Sync the setup script up and run it.
-        ip_list = handle.external_ips(max_attempts=_FETCH_IP_MAX_ATTEMPTS)
+        ip_list = handle.external_ips()
+        assert ip_list is not None, 'external_ips is not cached in handle'
         ssh_credentials = backend_utils.ssh_credential_from_yaml(
             handle.cluster_yaml)
         # Disable connection sharing for setup script to avoid old
@@ -2522,6 +2524,7 @@ def sync_down_logs(
                     f'{style.RESET_ALL}')

         ip_list = handle.external_ips()
+        assert ip_list is not None, 'external_ips is not cached in handle'
         ssh_credentials = backend_utils.ssh_credential_from_yaml(
             handle.cluster_yaml)
         runners = command_runner.SSHCommandRunner.make_runner_list(
@@ -2931,6 +2934,7 @@ def _check_existing_cluster(
     def _set_tpu_name(self, handle: ResourceHandle, tpu_name: str) -> None:
         """Sets TPU_NAME on all nodes."""
         ip_list = handle.external_ips()
+        assert ip_list is not None, 'external_ips is not cached in handle'
         ssh_credentials = backend_utils.ssh_credential_from_yaml(
             handle.cluster_yaml)

@@ -2964,6 +2968,7 @@ def _execute_file_mounts(self, handle: ResourceHandle,
         logger.info(f'{fore.CYAN}Processing file mounts.{style.RESET_ALL}')
         start = time.time()
         ip_list = handle.external_ips()
+        assert ip_list is not None, 'external_ips is not cached in handle'
         ssh_credentials = backend_utils.ssh_credential_from_yaml(
             handle.cluster_yaml)
         runners = command_runner.SSHCommandRunner.make_runner_list(
@@ -3108,6 +3113,7 @@ def _execute_storage_mounts(self, handle: ResourceHandle,
                     f'storage mount{plural}.{style.RESET_ALL}')
         start = time.time()
         ip_list = handle.external_ips()
+        assert ip_list is not None, 'external_ips is not cached in handle'
         ssh_credentials = backend_utils.ssh_credential_from_yaml(
             handle.cluster_yaml)
         runners = command_runner.SSHCommandRunner.make_runner_list(
@@ -3141,6 +3147,8 @@ def _execute_task_one_node(self, handle: ResourceHandle,
         log_dir = os.path.join(self.log_dir, 'tasks')

         accelerator_dict = backend_utils.get_task_demands_dict(task)
+        internal_ips = handle.internal_ips()
+        assert internal_ips is not None, 'internal_ips is not cached in handle'

         codegen = RayCodeGen()
         is_local = isinstance(handle.launched_resources.cloud, clouds.Local)
@@ -3151,9 +3159,7 @@
             setup_log_path=os.path.join(log_dir, 'setup.log'),
             is_local=is_local)
         codegen.add_gang_scheduling_placement_group(
-            1,
-            accelerator_dict,
-            stable_cluster_internal_ips=handle.internal_ips())
+            1, accelerator_dict, stable_cluster_internal_ips=internal_ips)

         if callable(task.run):
             run_fn_code = textwrap.dedent(inspect.getsource(task.run))
@@ -3196,6 +3202,8 @@ def _execute_task_n_nodes(self, handle: ResourceHandle, task: task_lib.Task,
         log_dir_base = self.log_dir
         log_dir = os.path.join(log_dir_base, 'tasks')
         accelerator_dict = backend_utils.get_task_demands_dict(task)
+        internal_ips = handle.internal_ips()
+        assert internal_ips is not None, 'internal_ips is not cached in handle'

         # If TPU VM Pods is used, #num_nodes should be #num_tpu_devices
         is_tpu_vm_pod = tpu_utils.is_tpu_vm_pod(handle.launched_resources)
@@ -3216,7 +3224,7 @@
         codegen.add_gang_scheduling_placement_group(
             num_actual_nodes,
             accelerator_dict,
-            stable_cluster_internal_ips=handle.internal_ips())
+            stable_cluster_internal_ips=internal_ips)

         if callable(task.run):
             run_fn_code = textwrap.dedent(inspect.getsource(task.run))
4 changes: 1 addition & 3 deletions tests/test_smoke.py
@@ -523,14 +523,12 @@ def test_tpu_vm():


 # ---------- TPU VM Pod. ----------
-# Mark slow because it's expensive to run.
-@pytest.mark.slow
 def test_tpu_vm_pod():
     name = _get_cluster_name()
     test = Test(
         'tpu_pod',
         [
-            f'sky launch -y -c {name} examples/tpu/tpuvm_mnist.yaml --gpus tpu-v2-32',
+            f'sky launch -y -c {name} examples/tpu/tpuvm_mnist.yaml --gpus tpu-v2-32 --use-spot --zone europe-west4-a',
             f'sky logs {name} 1',  # Ensure the job finished.
             f'sky logs {name} 1 --status',  # Ensure the job succeeded.
         ],
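
With the `slow` marker removed, the pod test now runs with the rest of the smoke suite; it can also be selected on its own with standard pytest node syntax, e.g. `pytest tests/test_smoke.py::test_tpu_vm_pod`.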
