Skip to content

Commit

Permalink
Implement getting OCI offers (#1215)
Browse files Browse the repository at this point in the history
  • Loading branch information
jvstme authored May 15, 2024
1 parent 177036a commit 7b5dda4
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 3 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_long_description():
"cachetools",
"dnspython",
"grpcio>=1.50", # indirect
"gpuhunt>=0.0.9rc3",
"gpuhunt>=0.0.9",
"sentry-sdk[fastapi]",
"httpx",
"aiorwlock",
Expand Down
60 changes: 58 additions & 2 deletions src/dstack/_internal/core/backends/oci/compute.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,69 @@
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional

from dstack._internal.core.backends.base.compute import Compute
from dstack._internal.core.backends.base.offers import get_catalog_offers
from dstack._internal.core.backends.oci import resources
from dstack._internal.core.backends.oci.auth import get_client_config
from dstack._internal.core.backends.oci.config import OCIConfig
from dstack._internal.core.models.instances import InstanceOfferWithAvailability
from dstack._internal.core.backends.oci.region import make_region_clients_map
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.instances import (
InstanceAvailability,
InstanceOffer,
InstanceOfferWithAvailability,
)
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run

SUPPORTED_SHAPE_FAMILIES = [
"VM.Standard2.",
"VM.DenseIO1.",
"VM.DenseIO2.",
"VM.GPU2.",
"VM.GPU3.",
"VM.GPU.A10.",
]


class OCICompute(Compute):
def __init__(self, config: OCIConfig):
self.config = config
# TODO(#1194): use a separate compartment instead of tenancy root
self.compartment_id = get_client_config(config.creds)["tenancy"]
self.regions = make_region_clients_map(config.regions, config.creds)

def get_offers(
self, requirements: Optional[Requirements] = None
) -> List[InstanceOfferWithAvailability]:
raise NotImplementedError
offers = get_catalog_offers(
backend=BackendType.OCI,
locations=self.config.regions,
requirements=requirements,
extra_filter=_supported_instances,
)

with ThreadPoolExecutor(max_workers=8) as executor:
shapes_quota = resources.get_shapes_quota(self.regions, self.compartment_id, executor)
offers_within_quota = [
offer for offer in offers if offer.instance.name in shapes_quota[offer.region]
]
shapes_availability = resources.get_shapes_availability(
offers_within_quota, self.regions, self.compartment_id, executor
)

offers_with_availability = []
for offer in offers:
if offer.instance.name in shapes_availability[offer.region]:
availability = InstanceAvailability.AVAILABLE
elif offer.instance.name in shapes_quota[offer.region]:
availability = InstanceAvailability.NOT_AVAILABLE
else:
availability = InstanceAvailability.NO_QUOTA
offers_with_availability.append(
InstanceOfferWithAvailability(**offer.dict(), availability=availability)
)

return offers_with_availability

def run_job(
self,
Expand All @@ -29,3 +79,9 @@ def terminate_instance(
self, instance_id: str, region: str, backend_data: Optional[str] = None
) -> None:
raise NotImplementedError


def _supported_instances(offer: InstanceOffer) -> bool:
if "Flex" in offer.instance.name:
return False
return any(map(offer.instance.name.startswith, SUPPORTED_SHAPE_FAMILIES))
13 changes: 13 additions & 0 deletions src/dstack/_internal/core/backends/oci/region.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from functools import cached_property
from typing import Dict, Iterable

import oci
from typing_extensions import Any, List, Mapping
Expand Down Expand Up @@ -28,6 +29,18 @@ def availability_domains(self) -> List[oci.identity.models.AvailabilityDomain]:
return self.identity_client.list_availability_domains(self.client_config["tenancy"]).data


def make_region_clients_map(
region_names: Iterable[str], creds: AnyOCICreds
) -> Dict[str, OCIRegionClient]:
config = get_client_config(creds)
result = {}
for region_name in region_names:
region_config = dict(config)
region_config["region"] = region_name
result[region_name] = OCIRegionClient(region_config)
return result


def get_subscribed_region_names(creds: AnyOCICreds) -> List[str]:
config = get_client_config(creds)
region = OCIRegionClient(config)
Expand Down
140 changes: 140 additions & 0 deletions src/dstack/_internal/core/backends/oci/resources.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
from concurrent.futures import Executor, as_completed
from itertools import islice
from typing import Dict, Iterable, List, Mapping, Set

import oci

from dstack._internal.core.backends.oci.region import OCIRegionClient
from dstack._internal.core.models.instances import InstanceOffer

LIST_SHAPES_MAX_LIMIT = 100
CAPACITY_REPORT_MAX_SHAPES = 10 # undocumented, found by experiment


def list_shapes(
client: oci.core.ComputeClient, compartment_id: str
) -> List[oci.core.models.Shape]:
"""
Lists shapes allowed to be used in the region the `client` is bound to.
"""

shapes = []
page = oci.core.compute_client.missing # first page

while page is not None:
resp = client.list_shapes(compartment_id, limit=LIST_SHAPES_MAX_LIMIT, page=page)
shapes.extend(resp.data)
page = resp.headers.get("opc-next-page")

return shapes


def get_shapes_quota(
regions: Mapping[str, OCIRegionClient], compartment_id: str, executor: Executor
) -> Dict[str, Set[str]]:
"""
Returns a mapping of region names to sets of shape names allowed to be used
in these regions.
"""

future_to_region_name = {}
for region_name, region_client in regions.items():
future = executor.submit(list_shapes, region_client.compute_client, compartment_id)
future_to_region_name[future] = region_name

result = {}
for future in as_completed(future_to_region_name):
region_name = future_to_region_name[future]
shape_names = {shape.shape for shape in future.result()}
result[region_name] = shape_names

return result


def check_availability_in_domain(
shape_names: Iterable[str],
availability_domain_name: str,
client: oci.core.ComputeClient,
compartment_id: str,
) -> Set[str]:
"""
Returns a subset of `shape_names` with only the shapes available in
`availability_domain_name`.
"""

unchecked = set(shape_names)
available = set()

while chunk := set(islice(unchecked, CAPACITY_REPORT_MAX_SHAPES)):
unchecked -= chunk
report: oci.core.models.ComputeCapacityReport = client.create_compute_capacity_report(
oci.core.models.CreateComputeCapacityReportDetails(
compartment_id=compartment_id,
availability_domain=availability_domain_name,
shape_availabilities=[
oci.core.models.CreateCapacityReportShapeAvailabilityDetails(
instance_shape=shape_name,
)
for shape_name in chunk
],
)
).data
item: oci.core.models.CapacityReportShapeAvailability

for item in report.shape_availabilities:
if item.availability_status == item.AVAILABILITY_STATUS_AVAILABLE:
available.add(item.instance_shape)

return available


def check_availability_in_region(
shape_names: Iterable[str], region: OCIRegionClient, compartment_id: str
) -> Set[str]:
"""
Returns a subset of `shape_names` with only the shapes available in at least
one availability domain within `region`.
"""

all_shapes = set(shape_names)
available_shapes = set()

for availability_domain in region.availability_domains:
available_shapes |= check_availability_in_domain(
shape_names=all_shapes - available_shapes,
availability_domain_name=availability_domain.name,
client=region.compute_client,
compartment_id=compartment_id,
)

return available_shapes


def get_shapes_availability(
offers: Iterable[InstanceOffer],
regions: Mapping[str, OCIRegionClient],
compartment_id: str,
executor: Executor,
) -> Dict[str, Set[str]]:
"""
Returns a mapping of region names to sets of shape names available in these
regions. Only shapes from `offers` are checked.
"""

shape_names_per_region = {region: set() for region in regions}
for offer in offers:
shape_names_per_region[offer.region].add(offer.instance.name)

future_to_region_name = {}
for region_name, shape_names in shape_names_per_region.items():
future = executor.submit(
check_availability_in_region, shape_names, regions[region_name], compartment_id
)
future_to_region_name[future] = region_name

result = {}
for future in as_completed(future_to_region_name):
region_name = future_to_region_name[future]
result[region_name] = future.result()

return result

0 comments on commit 7b5dda4

Please sign in to comment.