Skip to content

Commit

Permalink
[UX] warning before launching jobs/serve when using a reauth required…
Browse files Browse the repository at this point in the history
… credentials (#4479)

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* Update sky/backends/cloud_vm_ray_backend.py

Minor fix

* Update sky/clouds/aws.py

Co-authored-by: Romil Bhardwaj <[email protected]>

* wip

* minor changes

* wip

---------

Co-authored-by: hong <[email protected]>
Co-authored-by: Romil Bhardwaj <[email protected]>
  • Loading branch information
3 people authored Jan 6, 2025
1 parent e4939f9 commit 9828f6b
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 0 deletions.
36 changes: 36 additions & 0 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,42 @@ def _restore_block(new_block: Dict[str, Any], old_block: Dict[str, Any]):
return common_utils.dump_yaml_str(new_config)


def get_expirable_clouds(
enabled_clouds: Sequence[clouds.Cloud]) -> List[clouds.Cloud]:
"""Returns a list of clouds that use local credentials and whose credentials can expire.
This function checks each cloud in the provided sequence to determine if it uses local credentials
and if its credentials can expire. If both conditions are met, the cloud is added to the list of
expirable clouds.
Args:
enabled_clouds (Sequence[clouds.Cloud]): A sequence of cloud objects to check.
Returns:
list[clouds.Cloud]: A list of cloud objects that use local credentials and whose credentials can expire.
"""
expirable_clouds = []
local_credentials_value = schemas.RemoteIdentityOptions.LOCAL_CREDENTIALS.value
for cloud in enabled_clouds:
remote_identities = skypilot_config.get_nested(
(str(cloud).lower(), 'remote_identity'), None)
if remote_identities is None:
remote_identities = schemas.get_default_remote_identity(
str(cloud).lower())

local_credential_expiring = cloud.can_credential_expire()
if isinstance(remote_identities, str):
if remote_identities == local_credentials_value and local_credential_expiring:
expirable_clouds.append(cloud)
elif isinstance(remote_identities, list):
for profile in remote_identities:
if list(profile.values(
))[0] == local_credentials_value and local_credential_expiring:
expirable_clouds.append(cloud)
break
return expirable_clouds


# TODO: too many things happening here - leaky abstraction. Refactor.
@timeline.event
def write_cluster_config(
Expand Down
17 changes: 17 additions & 0 deletions sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import sky
from sky import backends
from sky import check as sky_check
from sky import cloud_stores
from sky import clouds
from sky import exceptions
Expand Down Expand Up @@ -1996,6 +1997,22 @@ def provision_with_retries(
skip_unnecessary_provisioning else None)

failover_history: List[Exception] = list()
# If the user is using local credentials which may expire, the
# controller may leak resources if the credentials expire while a job
# is running. Here we check the enabled clouds and expiring credentials
# and raise a warning to the user.
if task.is_controller_task():
enabled_clouds = sky_check.get_cached_enabled_clouds_or_refresh()
expirable_clouds = backend_utils.get_expirable_clouds(
enabled_clouds)

if len(expirable_clouds) > 0:
warnings = (f'\033[93mWarning: Credentials used for '
f'{expirable_clouds} may expire. Clusters may be '
f'leaked if the credentials expire while jobs '
f'are running. It is recommended to use credentials'
f' that never expire or a service account.\033[0m')
logger.warning(warnings)

# Retrying launchable resources.
while True:
Expand Down
24 changes: 24 additions & 0 deletions sky/clouds/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,24 @@ class AWSIdentityType(enum.Enum):
# region us-east-1 config-file ~/.aws/config
SHARED_CREDENTIALS_FILE = 'shared-credentials-file'

def can_credential_expire(self) -> bool:
"""Check if the AWS identity type can expire.
SSO,IAM_ROLE and CONTAINER_ROLE are temporary credentials and refreshed
automatically. ENV and SHARED_CREDENTIALS_FILE are short-lived
credentials without refresh.
IAM ROLE:
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
SSO/Container-role refresh token:
https://docs.aws.amazon.com/solutions/latest/dea-api/auth-refreshtoken.html
"""
# TODO(hong): Add a CLI based check for the expiration of the temporary
# credentials
expirable_types = {
AWSIdentityType.ENV, AWSIdentityType.SHARED_CREDENTIALS_FILE
}
return self in expirable_types


@clouds.CLOUD_REGISTRY.register
class AWS(clouds.Cloud):
Expand Down Expand Up @@ -860,6 +878,12 @@ def get_credential_file_mounts(self) -> Dict[str, str]:
if os.path.exists(os.path.expanduser(f'~/.aws/{filename}'))
}

@functools.lru_cache(maxsize=1)
def can_credential_expire(self) -> bool:
identity_type = self._current_identity_type()
return identity_type is not None and identity_type.can_credential_expire(
)

def instance_type_exists(self, instance_type):
return service_catalog.instance_type_exists(instance_type, clouds='aws')

Expand Down
4 changes: 4 additions & 0 deletions sky/clouds/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,10 @@ def get_credential_file_mounts(self) -> Dict[str, str]:
"""
raise NotImplementedError

def can_credential_expire(self) -> bool:
"""Returns whether the cloud credential can expire."""
return False

@classmethod
def get_image_size(cls, image_id: str, region: Optional[str]) -> float:
"""Check the image size from the cloud.
Expand Down
9 changes: 9 additions & 0 deletions sky/clouds/gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ class GCPIdentityType(enum.Enum):

SHARED_CREDENTIALS_FILE = ''

def can_credential_expire(self) -> bool:
return self == GCPIdentityType.SHARED_CREDENTIALS_FILE


@clouds.CLOUD_REGISTRY.register
class GCP(clouds.Cloud):
Expand Down Expand Up @@ -863,6 +866,12 @@ def get_credential_file_mounts(self) -> Dict[str, str]:
pass
return credentials

@functools.lru_cache(maxsize=1)
def can_credential_expire(self) -> bool:
identity_type = self._get_identity_type()
return identity_type is not None and identity_type.can_credential_expire(
)

@classmethod
def _get_identity_type(cls) -> Optional[GCPIdentityType]:
try:
Expand Down

0 comments on commit 9828f6b

Please sign in to comment.