diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index 6e79469a819..bf92f442d2f 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -650,6 +650,42 @@ def _restore_block(new_block: Dict[str, Any], old_block: Dict[str, Any]): return common_utils.dump_yaml_str(new_config) +def get_expirable_clouds( + enabled_clouds: Sequence[clouds.Cloud]) -> List[clouds.Cloud]: + """Returns a list of clouds that use local credentials and whose credentials can expire. + + This function checks each cloud in the provided sequence to determine if it uses local credentials + and if its credentials can expire. If both conditions are met, the cloud is added to the list of + expirable clouds. + + Args: + enabled_clouds (Sequence[clouds.Cloud]): A sequence of cloud objects to check. + + Returns: + list[clouds.Cloud]: A list of cloud objects that use local credentials and whose credentials can expire. + """ + expirable_clouds = [] + local_credentials_value = schemas.RemoteIdentityOptions.LOCAL_CREDENTIALS.value + for cloud in enabled_clouds: + remote_identities = skypilot_config.get_nested( + (str(cloud).lower(), 'remote_identity'), None) + if remote_identities is None: + remote_identities = schemas.get_default_remote_identity( + str(cloud).lower()) + + local_credential_expiring = cloud.can_credential_expire() + if isinstance(remote_identities, str): + if remote_identities == local_credentials_value and local_credential_expiring: + expirable_clouds.append(cloud) + elif isinstance(remote_identities, list): + for profile in remote_identities: + if list(profile.values( + ))[0] == local_credentials_value and local_credential_expiring: + expirable_clouds.append(cloud) + break + return expirable_clouds + + # TODO: too many things happening here - leaky abstraction. Refactor. @timeline.event def write_cluster_config( diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 156f43181b2..c972928cd7d 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -26,6 +26,7 @@ import sky from sky import backends +from sky import check as sky_check from sky import cloud_stores from sky import clouds from sky import exceptions @@ -1996,6 +1997,22 @@ def provision_with_retries( skip_unnecessary_provisioning else None) failover_history: List[Exception] = list() + # If the user is using local credentials which may expire, the + # controller may leak resources if the credentials expire while a job + # is running. Here we check the enabled clouds and expiring credentials + # and raise a warning to the user. + if task.is_controller_task(): + enabled_clouds = sky_check.get_cached_enabled_clouds_or_refresh() + expirable_clouds = backend_utils.get_expirable_clouds( + enabled_clouds) + + if len(expirable_clouds) > 0: + warnings = (f'\033[93mWarning: Credentials used for ' + f'{expirable_clouds} may expire. Clusters may be ' + f'leaked if the credentials expire while jobs ' + f'are running. It is recommended to use credentials' + f' that never expire or a service account.\033[0m') + logger.warning(warnings) # Retrying launchable resources. while True: diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py index c665263e22e..a86a87f4feb 100644 --- a/sky/clouds/aws.py +++ b/sky/clouds/aws.py @@ -103,6 +103,24 @@ class AWSIdentityType(enum.Enum): # region us-east-1 config-file ~/.aws/config SHARED_CREDENTIALS_FILE = 'shared-credentials-file' + def can_credential_expire(self) -> bool: + """Check if the AWS identity type can expire. + + SSO,IAM_ROLE and CONTAINER_ROLE are temporary credentials and refreshed + automatically. ENV and SHARED_CREDENTIALS_FILE are short-lived + credentials without refresh. + IAM ROLE: + https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html + SSO/Container-role refresh token: + https://docs.aws.amazon.com/solutions/latest/dea-api/auth-refreshtoken.html + """ + # TODO(hong): Add a CLI based check for the expiration of the temporary + # credentials + expirable_types = { + AWSIdentityType.ENV, AWSIdentityType.SHARED_CREDENTIALS_FILE + } + return self in expirable_types + @clouds.CLOUD_REGISTRY.register class AWS(clouds.Cloud): @@ -860,6 +878,12 @@ def get_credential_file_mounts(self) -> Dict[str, str]: if os.path.exists(os.path.expanduser(f'~/.aws/{filename}')) } + @functools.lru_cache(maxsize=1) + def can_credential_expire(self) -> bool: + identity_type = self._current_identity_type() + return identity_type is not None and identity_type.can_credential_expire( + ) + def instance_type_exists(self, instance_type): return service_catalog.instance_type_exists(instance_type, clouds='aws') diff --git a/sky/clouds/cloud.py b/sky/clouds/cloud.py index 455baeaf5d9..2cb45ca14fc 100644 --- a/sky/clouds/cloud.py +++ b/sky/clouds/cloud.py @@ -536,6 +536,10 @@ def get_credential_file_mounts(self) -> Dict[str, str]: """ raise NotImplementedError + def can_credential_expire(self) -> bool: + """Returns whether the cloud credential can expire.""" + return False + @classmethod def get_image_size(cls, image_id: str, region: Optional[str]) -> float: """Check the image size from the cloud. diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index ff200f84147..3502fee8e1c 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -132,6 +132,9 @@ class GCPIdentityType(enum.Enum): SHARED_CREDENTIALS_FILE = '' + def can_credential_expire(self) -> bool: + return self == GCPIdentityType.SHARED_CREDENTIALS_FILE + @clouds.CLOUD_REGISTRY.register class GCP(clouds.Cloud): @@ -863,6 +866,12 @@ def get_credential_file_mounts(self) -> Dict[str, str]: pass return credentials + @functools.lru_cache(maxsize=1) + def can_credential_expire(self) -> bool: + identity_type = self._get_identity_type() + return identity_type is not None and identity_type.can_credential_expire( + ) + @classmethod def _get_identity_type(cls) -> Optional[GCPIdentityType]: try: