Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Minor] Improve completion of accelerator name when cloud is specified #3014

Merged
merged 1 commit into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions sky/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,8 +501,9 @@ def _set_accelerators(

# Canonicalize the accelerator names.
accelerators = {
accelerator_registry.canonicalize_accelerator_name(acc):
acc_count for acc, acc_count in accelerators.items()
accelerator_registry.canonicalize_accelerator_name(
acc, self._cloud): acc_count
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: What happens if a user has 1 cloud only as indicated by sky check but self._cloud is not specified? I guess we can leave that to the future.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is a good point! We can consider inferring the cloud for the whole Resources when only one cloud is enabled. Let's leave it to the future as that may need more test. : )

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, this might be debatable, as this may make an original working yaml fails, if a user add a new cloud.

for acc, acc_count in accelerators.items()
}

acc, _ = list(accelerators.items())[0]
Expand Down Expand Up @@ -1304,8 +1305,9 @@ def __setstate__(self, state):
accelerators = state.pop('_accelerators', None)
if accelerators is not None:
accelerators = {
accelerator_registry.canonicalize_accelerator_name(acc):
acc_count for acc, acc_count in accelerators.items()
accelerator_registry.canonicalize_accelerator_name(
acc, cloud=None): acc_count
for acc, acc_count in accelerators.items()
}
state['_accelerators'] = accelerators

Expand Down
16 changes: 14 additions & 2 deletions sky/utils/accelerator_registry.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
"""Accelerator registry."""
import typing
from typing import Optional

from sky.clouds import service_catalog
from sky.utils import ux_utils

if typing.TYPE_CHECKING:
from sky import clouds

# Canonicalized names of all accelerators (except TPUs) supported by SkyPilot.
# NOTE: Must include accelerators supported for local clusters.
#
Expand Down Expand Up @@ -67,8 +73,13 @@ def is_schedulable_non_gpu_accelerator(accelerator_name: str) -> bool:
return False


def canonicalize_accelerator_name(accelerator: str) -> str:
def canonicalize_accelerator_name(accelerator: str,
cloud: Optional['clouds.Cloud']) -> str:
"""Returns the canonical accelerator name."""
cloud_str = None
if cloud is not None:
cloud_str = str(cloud).lower()

# TPU names are always lowercase.
if accelerator.lower().startswith('tpu-'):
return accelerator.lower()
Expand All @@ -84,7 +95,8 @@ def canonicalize_accelerator_name(accelerator: str) -> str:
# To cover such cases, we should search the accelerator name
# in the service catalog.
searched = service_catalog.list_accelerators(name_filter=accelerator,
case_sensitive=False)
case_sensitive=False,
clouds=cloud_str)
names = list(searched.keys())

# Exact match.
Expand Down