Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ResourceHandle semantics #1481

Merged
merged 11 commits into from
Dec 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from sky.utils import command_runner
from sky.utils import subprocess_utils
from sky.utils import timeline
from sky.utils import tpu_utils
from sky.utils import ux_utils
from sky.utils import validator
from sky.usage import usage_lib
Expand Down Expand Up @@ -1152,10 +1153,11 @@ def get_node_ips(cluster_yaml: str,
ray_config = common_utils.read_yaml(cluster_yaml)
use_tpu_vm = ray_config['provider'].get('_has_tpus', False)
if use_tpu_vm:
ips = _get_tpu_vm_pod_ips(ray_config, get_internal_ips)
assert expected_num_nodes == 1, 'TPU VM only supports single node for now.'
if len(ips) != expected_num_nodes:
ips = _get_tpu_vm_pod_ips(ray_config, get_internal_ips)
if len(ips) != tpu_utils.get_num_tpu_devices(handle.launched_resources):
raise exceptions.FetchIPError(exceptions.FetchIPError.Reason.HEAD)
return ips

if get_internal_ips:
with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
Expand Down Expand Up @@ -1625,7 +1627,7 @@ def _update_cluster_status_no_lock(
# in ray's get IPs vs. ray runtime failing.
external_ips = handle.external_ips(use_cached_ips=False)
# This happens to a stopped TPU VM as we use gcloud to query the IP.
if len(external_ips) == 0:
if external_ips is None or len(external_ips) == 0:
raise exceptions.FetchIPError(
reason=exceptions.FetchIPError.Reason.HEAD)
if handle.launched_nodes == 1:
Expand Down Expand Up @@ -1693,8 +1695,10 @@ def _update_cluster_status_no_lock(
# that the cluster is partially preempted.
# TODO(zhwu): the definition of INIT should be audited/changed.
# Adding a new status UNHEALTHY for abnormal status can be a choice.
global_user_state.set_cluster_status(
cluster_name, global_user_state.ClusterStatus.INIT)
global_user_state.add_or_update_cluster(cluster_name,
handle,
ready=False,
is_launch=False)
return global_user_state.get_cluster_from_name(cluster_name)
# Now is_abnormal is False: either node_statuses is empty or all nodes are STOPPED.
backend = backends.CloudVmRayBackend()
Expand Down
28 changes: 18 additions & 10 deletions sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,12 +1155,12 @@ def _tpu_pod_setup(self, cluster_yaml: str,
run setup or launch ray cluster on TPU VM Pod nodes.
"""
ssh_credentials = backend_utils.ssh_credential_from_yaml(cluster_yaml)
all_ips = cluster_handle.external_ips()
all_ips = cluster_handle.external_ips(use_cached_ips=False)
num_tpu_devices = tpu_utils.get_num_tpu_devices(
cluster_handle.launched_resources)
if len(all_ips) != num_tpu_devices:
if all_ips is None or len(all_ips) != num_tpu_devices:
raise RuntimeError(
f'Number of nodes IPs: {len(all_ips)} does not'
f'Nodes IPs: {all_ips} does not'
f'match number of TPU devices: {num_tpu_devices}.')

# Get the private IP of head node for connecting Ray cluster.
Expand Down Expand Up @@ -1709,7 +1709,7 @@ def _update_stable_cluster_ips(self,

def internal_ips(self,
max_attempts: int = 1,
use_cached_ips: bool = True):
use_cached_ips: bool = True) -> Optional[List[str]]:
if not use_cached_ips:
self._update_stable_cluster_ips(max_attempts=max_attempts)
if self.stable_internal_external_ips is not None:
Expand All @@ -1718,7 +1718,7 @@ def internal_ips(self,

def external_ips(self,
max_attempts: int = 1,
use_cached_ips: bool = True):
use_cached_ips: bool = True) -> Optional[List[str]]:
if not use_cached_ips:
self._update_stable_cluster_ips(max_attempts=max_attempts)
if self.stable_internal_external_ips is not None:
Expand Down Expand Up @@ -2046,6 +2046,7 @@ def _sync_workdir(self, handle: ResourceHandle, workdir: Path) -> None:
fore = colorama.Fore
style = colorama.Style
ip_list = handle.external_ips()
assert ip_list is not None, 'external_ips is not cached in handle'
full_workdir = os.path.abspath(os.path.expanduser(workdir))

# These asserts have been validated at Task construction time.
Expand Down Expand Up @@ -2125,7 +2126,8 @@ def _setup(self, handle: ResourceHandle, task: task_lib.Task,
setup_sh_path = f.name
setup_file = os.path.basename(setup_sh_path)
# Sync the setup script up and run it.
ip_list = handle.external_ips(max_attempts=_FETCH_IP_MAX_ATTEMPTS)
ip_list = handle.external_ips()
assert ip_list is not None, 'external_ips is not cached in handle'
ssh_credentials = backend_utils.ssh_credential_from_yaml(
handle.cluster_yaml)
# Disable connection sharing for setup script to avoid old
Expand Down Expand Up @@ -2522,6 +2524,7 @@ def sync_down_logs(
f'{style.RESET_ALL}')

ip_list = handle.external_ips()
assert ip_list is not None, 'external_ips is not cached in handle'
ssh_credentials = backend_utils.ssh_credential_from_yaml(
handle.cluster_yaml)
runners = command_runner.SSHCommandRunner.make_runner_list(
Expand Down Expand Up @@ -2931,6 +2934,7 @@ def _check_existing_cluster(
def _set_tpu_name(self, handle: ResourceHandle, tpu_name: str) -> None:
"""Sets TPU_NAME on all nodes."""
ip_list = handle.external_ips()
assert ip_list is not None, 'external_ips is not cached in handle'
ssh_credentials = backend_utils.ssh_credential_from_yaml(
handle.cluster_yaml)

Expand Down Expand Up @@ -2964,6 +2968,7 @@ def _execute_file_mounts(self, handle: ResourceHandle,
logger.info(f'{fore.CYAN}Processing file mounts.{style.RESET_ALL}')
start = time.time()
ip_list = handle.external_ips()
assert ip_list is not None, 'external_ips is not cached in handle'
ssh_credentials = backend_utils.ssh_credential_from_yaml(
handle.cluster_yaml)
runners = command_runner.SSHCommandRunner.make_runner_list(
Expand Down Expand Up @@ -3108,6 +3113,7 @@ def _execute_storage_mounts(self, handle: ResourceHandle,
f'storage mount{plural}.{style.RESET_ALL}')
start = time.time()
ip_list = handle.external_ips()
assert ip_list is not None, 'external_ips is not cached in handle'
ssh_credentials = backend_utils.ssh_credential_from_yaml(
handle.cluster_yaml)
runners = command_runner.SSHCommandRunner.make_runner_list(
Expand Down Expand Up @@ -3141,6 +3147,8 @@ def _execute_task_one_node(self, handle: ResourceHandle,
log_dir = os.path.join(self.log_dir, 'tasks')

accelerator_dict = backend_utils.get_task_demands_dict(task)
internal_ips = handle.internal_ips()
assert internal_ips is not None, 'internal_ips is not cached in handle'

codegen = RayCodeGen()
is_local = isinstance(handle.launched_resources.cloud, clouds.Local)
Expand All @@ -3151,9 +3159,7 @@ def _execute_task_one_node(self, handle: ResourceHandle,
setup_log_path=os.path.join(log_dir, 'setup.log'),
is_local=is_local)
codegen.add_gang_scheduling_placement_group(
1,
accelerator_dict,
stable_cluster_internal_ips=handle.internal_ips())
1, accelerator_dict, stable_cluster_internal_ips=internal_ips)

if callable(task.run):
run_fn_code = textwrap.dedent(inspect.getsource(task.run))
Expand Down Expand Up @@ -3196,6 +3202,8 @@ def _execute_task_n_nodes(self, handle: ResourceHandle, task: task_lib.Task,
log_dir_base = self.log_dir
log_dir = os.path.join(log_dir_base, 'tasks')
accelerator_dict = backend_utils.get_task_demands_dict(task)
internal_ips = handle.internal_ips()
assert internal_ips is not None, 'internal_ips is not cached in handle'

# If TPU VM Pods is used, #num_nodes should be #num_tpu_devices
is_tpu_vm_pod = tpu_utils.is_tpu_vm_pod(handle.launched_resources)
Expand All @@ -3216,7 +3224,7 @@ def _execute_task_n_nodes(self, handle: ResourceHandle, task: task_lib.Task,
codegen.add_gang_scheduling_placement_group(
num_actual_nodes,
accelerator_dict,
stable_cluster_internal_ips=handle.internal_ips())
stable_cluster_internal_ips=internal_ips)

if callable(task.run):
run_fn_code = textwrap.dedent(inspect.getsource(task.run))
Expand Down
4 changes: 1 addition & 3 deletions tests/test_smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,14 +523,12 @@ def test_tpu_vm():


# ---------- TPU VM Pod. ----------
# Mark slow because it's expensive to run.
@pytest.mark.slow
def test_tpu_vm_pod():
name = _get_cluster_name()
test = Test(
'tpu_pod',
[
f'sky launch -y -c {name} examples/tpu/tpuvm_mnist.yaml --gpus tpu-v2-32',
f'sky launch -y -c {name} examples/tpu/tpuvm_mnist.yaml --gpus tpu-v2-32 --use-spot --zone europe-west4-a',
f'sky logs {name} 1', # Ensure the job finished.
f'sky logs {name} 1 --status', # Ensure the job succeeded.
],
Expand Down