Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[User identity] Fix identity check #1550

Merged
merged 9 commits into from
Dec 25, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 67 additions & 39 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
from sky.utils import common_utils
from sky.utils import command_runner
from sky.utils import env_options
from sky.utils import log_utils
from sky.utils import subprocess_utils
from sky.utils import timeline
from sky.utils import tpu_utils
Expand Down Expand Up @@ -1695,8 +1694,8 @@ def check_owner_identity(cluster_name: str) -> None:
elif owner_identity != current_user_identity:
with ux_utils.print_exception_no_traceback():
raise exceptions.ClusterOwnerIdentityMismatchError(
f'Cluster {cluster_name!r} ({cloud}) is owned by account '
f'{owner_identity!r}, but the currently activated account '
f'{cluster_name!r} ({cloud}) is owned by account '
f'{owner_identity!r}, but the activated account '
f'is {current_user_identity!r}.')


Expand Down Expand Up @@ -1818,13 +1817,23 @@ def _update_cluster_status(
design of the cluster status and transition, please refer to the
sky/design_docs/cluster_status.md

Args:
cluster_name: The name of the cluster.
acquire_per_cluster_status_lock: Whether to acquire the per-cluster lock
before updating the status.
need_owner_identity_check: Whether to check the owner identity before updating

Returns:
If the cluster is terminated or does not exist, return None.
Otherwise returns the input record with status and ip potentially updated.
If the cluster is terminated or does not exist, return None.
Otherwise returns the input record with status and handle potentially updated.

Raises:
exceptions.ClusterOwnerIdentityMismatchError: if the current user is not the
same as the user who created the cluster.
exceptions.CloudUserIdentityError: if we fail to get the current user
identity.
exceptions.ClusterStatusFetchingError: the cluster status cannot be
fetched from the cloud provider.
fetched from the cloud provider.
"""
if not acquire_per_cluster_status_lock:
return _update_cluster_status_no_lock(cluster_name)
Expand All @@ -1843,21 +1852,28 @@ def _update_cluster_status(
return global_user_state.get_cluster_from_name(cluster_name)


@timeline.event
def refresh_cluster_status_handle(
cluster_name: str,
*,
force_refresh: bool = False,
acquire_per_cluster_status_lock: bool = True,
suppress_error: bool = False
) -> Tuple[Optional[global_user_state.ClusterStatus],
Optional[backends.Backend.ResourceHandle]]:
"""Refresh the cluster status and return the status and handle.
def _refresh_cluster_record(
cluster_name: str,
*,
force_refresh: bool = False,
acquire_per_cluster_status_lock: bool = True
) -> Optional[Dict[str, Any]]:
"""Refresh the cluster, and return the possibly updated record.

This function will also check the owner identity of the cluster, and raise
exceptions if the current user is not the same as the user who created the
cluster.

Args:
cluster_name: The name of the cluster.
force_refresh: refresh the cluster status as long as the cluster exists.
acquire_per_cluster_status_lock: Whether to acquire the per-cluster lock
before updating the status.

Returns:
If the cluster is terminated or does not exist, return None.
Otherwise returns the cluster record.

Raises:
exceptions.ClusterOwnerIdentityMismatchError: if the current user is not the
same as the user who created the cluster.
Expand All @@ -1866,9 +1882,10 @@ def refresh_cluster_status_handle(
exceptions.ClusterStatusFetchingError: the cluster status cannot be
fetched from the cloud provider.
"""

record = global_user_state.get_cluster_from_name(cluster_name)
if record is None:
return None, None
return None
check_owner_identity(cluster_name)

handle = record['handle']
Expand All @@ -1878,21 +1895,32 @@ def refresh_cluster_status_handle(
record['status'] != global_user_state.ClusterStatus.STOPPED and
record['autostop'] >= 0)
if force_refresh or has_autostop or use_spot:
try:
record = _update_cluster_status(cluster_name,
acquire_per_cluster_status_lock=
acquire_per_cluster_status_lock)
if record is None:
return None, None
except (exceptions.ClusterOwnerIdentityMismatchError,
exceptions.CloudUserIdentityError,
exceptions.ClusterStatusFetchingError) as e:
if suppress_error:
logger.debug(
f'Failed to refresh cluster {cluster_name!r} due to {e}'
)
return None, None
raise
record = _update_cluster_status(
cluster_name,
acquire_per_cluster_status_lock=acquire_per_cluster_status_lock)
return record


@timeline.event
def refresh_cluster_status_handle(
cluster_name: str,
*,
force_refresh: bool = False,
acquire_per_cluster_status_lock: bool = True,
) -> Tuple[Optional[global_user_state.ClusterStatus],
Optional[backends.Backend.ResourceHandle]]:
"""Refresh the cluster, and return the possibly updated status and handle.

This is a wrapper of refresh_cluster_record, which returns the status and
handle of the cluster.
Please refer to the docstring of refresh_cluster_record for the details.
"""
record = _refresh_cluster_record(
cluster_name,
force_refresh=force_refresh,
acquire_per_cluster_status_lock=acquire_per_cluster_status_lock)
if record is None:
return None, None
return record['status'], record['handle']


Expand Down Expand Up @@ -2036,10 +2064,13 @@ def _is_local_cluster(record):

def _refresh_cluster(cluster_name):
try:
record = _update_cluster_status(
cluster_name, acquire_per_cluster_status_lock=True)
record = _refresh_cluster_record(
cluster_name,
force_refresh=True,
acquire_per_cluster_status_lock=True)
except (exceptions.ClusterStatusFetchingError,
exceptions.ClusterOwnerIdentityMismatchError) as e:
exceptions.ClusterOwnerIdentityMismatchError,
exceptions.ClusterStatusFetchingError) as e:
record = {'status': 'UNKNOWN', 'error': e}
progress.update(task, advance=1)
return record
Expand Down Expand Up @@ -2085,11 +2116,8 @@ def _refresh_cluster(cluster_name):
plural = 's' if len(failed_clusters) > 1 else ''
logger.warning(f'{yellow}Failed to refresh status for '
f'{len(failed_clusters)} cluster{plural}:{reset}')
table = log_utils.create_table(['Cluster', 'Error'])
for cluster_name, e in failed_clusters:
table.add_row([cluster_name, str(e)])
logger.warning(table)

logger.warning(f' {bright}{cluster_name}{reset}: {e}')
return kept_records


Expand Down