Skip to content

Commit

Permalink
fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
suquark committed Jul 25, 2023
1 parent 065de2b commit 9defb38
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 26 deletions.
4 changes: 4 additions & 0 deletions sky/provision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import functools
import importlib
import inspect

from sky import status_lib


Expand Down Expand Up @@ -46,6 +47,9 @@ def query_instances(
"""Query instances.
Returns a dictionary of instance IDs and status.
A None status means the instance is marked as "terminated"
or "terminating".
"""
raise NotImplementedError

Expand Down
48 changes: 22 additions & 26 deletions sky/provision/aws/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Dict, List, Any, Optional

from botocore import config

from sky.adaptors import aws
from sky import status_lib

Expand All @@ -11,6 +12,20 @@
TAG_RAY_NODE_KIND = 'ray-node-type'


def _default_ec2_resource(region: str) -> Any:
return aws.resource(
'ec2',
region_name=region,
config=config.Config(retries={'max_attempts': BOTO_MAX_RETRIES}))


def _cluster_name_filter(cluster_name: str) -> List[Dict[str, Any]]:
return [{
'Name': f'tag:{TAG_RAY_CLUSTER_NAME}',
'Values': [cluster_name],
}]


def _filter_instances(ec2, filters: List[Dict[str, Any]],
included_instances: Optional[List[str]],
excluded_instances: Optional[List[str]]):
Expand All @@ -33,18 +48,11 @@ def query_instances(
cluster_name: str,
provider_config: Optional[Dict[str, Any]] = None,
) -> Dict[str, Optional[status_lib.ClusterStatus]]:
"""See sky/provision/__init__.py"""
assert provider_config is not None, (cluster_name, provider_config)
region = provider_config['region']
ec2 = aws.resource(
'ec2',
region_name=region,
config=config.Config(retries={'max_attempts': BOTO_MAX_RETRIES}))
filters = [
{
'Name': f'tag:{TAG_RAY_CLUSTER_NAME}',
'Values': [cluster_name],
},
]
ec2 = _default_ec2_resource(region)
filters = _cluster_name_filter(cluster_name)
instances = ec2.instances.filter(Filters=filters)
status_map = {
'pending': status_lib.ClusterStatus.INIT,
Expand All @@ -71,19 +79,13 @@ def stop_instances(
"""See sky/provision/__init__.py"""
assert provider_config is not None, (cluster_name, provider_config)
region = provider_config['region']
ec2 = aws.resource(
'ec2',
region_name=region,
config=config.Config(retries={'max_attempts': BOTO_MAX_RETRIES}))
ec2 = _default_ec2_resource(region)
filters = [
{
'Name': 'instance-state-name',
'Values': ['pending', 'running'],
},
{
'Name': f'tag:{TAG_RAY_CLUSTER_NAME}',
'Values': [cluster_name],
},
*_cluster_name_filter(cluster_name),
]
if worker_only:
filters.append({
Expand All @@ -108,20 +110,14 @@ def terminate_instances(
"""See sky/provision/__init__.py"""
assert provider_config is not None, (cluster_name, provider_config)
region = provider_config['region']
ec2 = aws.resource(
'ec2',
region_name=region,
config=config.Config(retries={'max_attempts': BOTO_MAX_RETRIES}))
ec2 = _default_ec2_resource(region)
filters = [
{
'Name': 'instance-state-name',
# exclude 'shutting-down' or 'terminated' states
'Values': ['pending', 'running', 'stopping', 'stopped'],
},
{
'Name': f'tag:{TAG_RAY_CLUSTER_NAME}',
'Values': [cluster_name],
},
*_cluster_name_filter(cluster_name),
]
if worker_only:
filters.append({
Expand Down

0 comments on commit 9defb38

Please sign in to comment.