From 9defb38c9a484eed620a44ab6b79f8071dda5578 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Mon, 24 Jul 2023 23:57:46 -0700 Subject: [PATCH] fix comments --- sky/provision/__init__.py | 4 +++ sky/provision/aws/instance.py | 48 ++++++++++++++++------------------- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/sky/provision/__init__.py b/sky/provision/__init__.py index e8af44f500c..3b22c1f124c 100644 --- a/sky/provision/__init__.py +++ b/sky/provision/__init__.py @@ -8,6 +8,7 @@ import functools import importlib import inspect + from sky import status_lib @@ -46,6 +47,9 @@ def query_instances( """Query instances. Returns a dictionary of instance IDs and status. + + A None status means the instance is marked as "terminated" + or "terminating". """ raise NotImplementedError diff --git a/sky/provision/aws/instance.py b/sky/provision/aws/instance.py index 270d6bfbb04..5f523e7a5a4 100644 --- a/sky/provision/aws/instance.py +++ b/sky/provision/aws/instance.py @@ -2,6 +2,7 @@ from typing import Dict, List, Any, Optional from botocore import config + from sky.adaptors import aws from sky import status_lib @@ -11,6 +12,20 @@ TAG_RAY_NODE_KIND = 'ray-node-type' +def _default_ec2_resource(region: str) -> Any: + return aws.resource( + 'ec2', + region_name=region, + config=config.Config(retries={'max_attempts': BOTO_MAX_RETRIES})) + + +def _cluster_name_filter(cluster_name: str) -> List[Dict[str, Any]]: + return [{ + 'Name': f'tag:{TAG_RAY_CLUSTER_NAME}', + 'Values': [cluster_name], + }] + + def _filter_instances(ec2, filters: List[Dict[str, Any]], included_instances: Optional[List[str]], excluded_instances: Optional[List[str]]): @@ -33,18 +48,11 @@ def query_instances( cluster_name: str, provider_config: Optional[Dict[str, Any]] = None, ) -> Dict[str, Optional[status_lib.ClusterStatus]]: + """See sky/provision/__init__.py""" assert provider_config is not None, (cluster_name, provider_config) region = provider_config['region'] - ec2 = aws.resource( - 'ec2', - region_name=region, - config=config.Config(retries={'max_attempts': BOTO_MAX_RETRIES})) - filters = [ - { - 'Name': f'tag:{TAG_RAY_CLUSTER_NAME}', - 'Values': [cluster_name], - }, - ] + ec2 = _default_ec2_resource(region) + filters = _cluster_name_filter(cluster_name) instances = ec2.instances.filter(Filters=filters) status_map = { 'pending': status_lib.ClusterStatus.INIT, @@ -71,19 +79,13 @@ def stop_instances( """See sky/provision/__init__.py""" assert provider_config is not None, (cluster_name, provider_config) region = provider_config['region'] - ec2 = aws.resource( - 'ec2', - region_name=region, - config=config.Config(retries={'max_attempts': BOTO_MAX_RETRIES})) + ec2 = _default_ec2_resource(region) filters = [ { 'Name': 'instance-state-name', 'Values': ['pending', 'running'], }, - { - 'Name': f'tag:{TAG_RAY_CLUSTER_NAME}', - 'Values': [cluster_name], - }, + *_cluster_name_filter(cluster_name), ] if worker_only: filters.append({ @@ -108,20 +110,14 @@ def terminate_instances( """See sky/provision/__init__.py""" assert provider_config is not None, (cluster_name, provider_config) region = provider_config['region'] - ec2 = aws.resource( - 'ec2', - region_name=region, - config=config.Config(retries={'max_attempts': BOTO_MAX_RETRIES})) + ec2 = _default_ec2_resource(region) filters = [ { 'Name': 'instance-state-name', # exclude 'shutting-down' or 'terminated' states 'Values': ['pending', 'running', 'stopping', 'stopped'], }, - { - 'Name': f'tag:{TAG_RAY_CLUSTER_NAME}', - 'Values': [cluster_name], - }, + *_cluster_name_filter(cluster_name), ] if worker_only: filters.append({