Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENG-504 update cluster, node pool patch methods #139

Merged
merged 2 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions neuro_config_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@
AWSCloudProvider,
AWSCredentials,
AWSStorage,
AWSStorageOptions,
AzureCloudProvider,
AzureCredentials,
AzureReplicationType,
AzureStorage,
AzureStorageOptions,
AzureStorageTier,
BucketsConfig,
CloudProvider,
Expand All @@ -36,7 +34,6 @@
GoogleCloudProvider,
GoogleFilestoreTier,
GoogleStorage,
GoogleStorageOptions,
GrafanaCredentials,
HelmRegistryConfig,
IdleJobConfig,
Expand All @@ -52,6 +49,10 @@
OnPremCloudProvider,
OpenStackCredentials,
OrchestratorConfig,
PatchClusterRequest,
PatchNodePoolResourcesRequest,
PatchNodePoolSizeRequest,
PatchOrchestratorConfigRequest,
RegistryConfig,
ResourcePoolType,
ResourcePreset,
Expand All @@ -61,7 +62,6 @@
Storage,
StorageConfig,
StorageInstance,
StorageOptions,
TPUPreset,
TPUResource,
VCDCloudProvider,
Expand All @@ -79,12 +79,10 @@
"AWSCloudProvider",
"AWSCredentials",
"AWSStorage",
"AWSStorageOptions",
"AzureCloudProvider",
"AzureCredentials",
"AzureReplicationType",
"AzureStorage",
"AzureStorageOptions",
"AzureStorageTier",
"BucketsConfig",
"CloudProvider",
Expand All @@ -106,7 +104,6 @@
"GoogleCloudProvider",
"GoogleFilestoreTier",
"GoogleStorage",
"GoogleStorageOptions",
"GrafanaCredentials",
"HelmRegistryConfig",
"IdleJobConfig",
Expand All @@ -122,6 +119,10 @@
"OnPremCloudProvider",
"OpenStackCredentials",
"OrchestratorConfig",
"PatchClusterRequest",
"PatchNodePoolResourcesRequest",
"PatchNodePoolSizeRequest",
"PatchOrchestratorConfigRequest",
"RegistryConfig",
"ResourcePoolType",
"ResourcePreset",
Expand All @@ -131,7 +132,6 @@
"Storage",
"StorageConfig",
"StorageInstance",
"StorageOptions",
"TPUPreset",
"TPUResource",
"VCDCloudProvider",
Expand Down
85 changes: 13 additions & 72 deletions neuro_config_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import abc
import logging
import sys
from collections.abc import AsyncIterator, Mapping, Sequence
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from dataclasses import dataclass
Expand All @@ -14,33 +13,18 @@
from yarl import URL

from .entities import (
BucketsConfig,
CloudProviderOptions,
CloudProviderType,
Cluster,
CredentialsConfig,
DisksConfig,
DNSConfig,
EnergyConfig,
IngressConfig,
MetricsConfig,
MonitoringConfig,
NodePool,
NotificationType,
OrchestratorConfig,
RegistryConfig,
PatchClusterRequest,
PatchNodePoolResourcesRequest,
PatchNodePoolSizeRequest,
ResourcePreset,
SecretsConfig,
StorageConfig,
)
from .factories import EntityFactory, PayloadFactory

if sys.version_info >= (3, 9):
from zoneinfo import ZoneInfo
else:
# why not backports.zoneinfo: https://github.com/pganssle/zoneinfo/issues/125
from backports.zoneinfo._zoneinfo import ZoneInfo

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -170,56 +154,10 @@ async def create_blank_cluster(
return await self.get_cluster(name)

async def patch_cluster(
self,
name: str,
*,
credentials: CredentialsConfig | None = None,
storage: StorageConfig | None = None,
registry: RegistryConfig | None = None,
orchestrator: OrchestratorConfig | None = None,
monitoring: MonitoringConfig | None = None,
secrets: SecretsConfig | None = None,
metrics: MetricsConfig | None = None,
disks: DisksConfig | None = None,
buckets: BucketsConfig | None = None,
ingress: IngressConfig | None = None,
dns: DNSConfig | None = None,
timezone: ZoneInfo | None = None,
energy: EnergyConfig | None = None,
token: str | None = None,
self, name: str, request: PatchClusterRequest, *, token: str | None = None
) -> Cluster:
path = self._endpoints.cluster(name)
payload: dict[str, Any] = {}
if credentials:
payload["credentials"] = self._payload_factory.create_credentials(
credentials
)
if storage:
payload["storage"] = self._payload_factory.create_storage(storage)
if registry:
payload["registry"] = self._payload_factory.create_registry(registry)
if orchestrator:
payload["orchestrator"] = self._payload_factory.create_orchestrator(
orchestrator
)
if monitoring:
payload["monitoring"] = self._payload_factory.create_monitoring(monitoring)
if secrets:
payload["secrets"] = self._payload_factory.create_secrets(secrets)
if metrics:
payload["metrics"] = self._payload_factory.create_metrics(metrics)
if disks:
payload["disks"] = self._payload_factory.create_disks(disks)
if buckets:
payload["buckets"] = self._payload_factory.create_buckets(buckets)
if ingress:
payload["ingress"] = self._payload_factory.create_ingress(ingress)
if dns:
payload["dns"] = self._payload_factory.create_dns(dns)
if timezone:
payload["timezone"] = str(timezone)
if energy:
payload["energy"] = self._payload_factory.create_energy(energy)
payload = self._payload_factory.create_patch_cluster_request(request)
async with self._request(
"PATCH", path, headers=self._create_headers(token=token), json=payload
) as resp:
Expand Down Expand Up @@ -393,16 +331,19 @@ async def patch_node_pool(
self,
cluster_name: str,
node_pool_name: str,
request: PatchNodePoolSizeRequest | PatchNodePoolResourcesRequest,
*,
idle_size: int | None = None,
start_deployment: bool = True,
token: str | None = None,
) -> Cluster:
path = self._endpoints.node_pool(cluster_name, node_pool_name)
payload: dict[str, Any] = {}
if idle_size is not None:
payload["idle_size"] = idle_size
payload = self._payload_factory.create_patch_node_pool_request(request)
async with self._request(
"PATCH", path, headers=self._create_headers(token=token), json=payload
"PATCH",
path,
params={"start_deployment": str(start_deployment).lower()},
headers=self._create_headers(token=token),
json=payload,
) as response:
resp_payload = await response.json()
return self._entity_factory.create_cluster(resp_payload)
Expand Down
110 changes: 71 additions & 39 deletions neuro_config_client/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def is_vcd(self) -> bool:
class CloudProviderOptions:
type: CloudProviderType
node_pools: list[NodePoolOptions]
storages: list[StorageOptions]


@dataclass(frozen=True)
Expand All @@ -59,41 +58,13 @@ class VCDCloudProviderOptions(CloudProviderOptions):

@dataclass(frozen=True)
class NodePoolOptions:
id: str
machine_type: str
cpu: float
available_cpu: float
memory: int
available_memory: int
gpu: int | None = None
gpu_model: str | None = None


@dataclass(frozen=True)
class StorageOptions:
id: str


@dataclass(frozen=True)
class GoogleStorageOptions(StorageOptions):
tier: GoogleFilestoreTier
min_capacity: int
max_capacity: int


@dataclass(frozen=True)
class AWSStorageOptions(StorageOptions):
performance_mode: EFSPerformanceMode
throughput_mode: EFSThroughputMode
provisioned_throughput_mibps: int | None = None


@dataclass(frozen=True)
class AzureStorageOptions(StorageOptions):
tier: AzureStorageTier
replication_type: AzureReplicationType
min_file_share_size: int
max_file_share_size: int
available_cpu: float | None = None
available_memory: int | None = None
nvidia_gpu: int | None = None
nvidia_gpu_model: str | None = None


class NodeRole(str, enum.Enum):
Expand All @@ -105,7 +76,6 @@ class NodeRole(str, enum.Enum):
@dataclass(frozen=True)
class NodePool:
name: str
id: str | None = None
role: NodeRole = NodeRole.PLATFORM_JOB

min_size: int = 0
Expand All @@ -119,6 +89,7 @@ class NodePool:
available_memory: int | None = None

disk_size: int | None = None
available_disk_size: int | None = None
disk_type: str | None = None

nvidia_gpu: int | None = None
Expand All @@ -142,6 +113,35 @@ class NodePool:
cpu_max_watts: float = 0.0


@dataclass(frozen=True)
class PatchNodePoolSizeRequest:
min_size: int | None = None
max_size: int | None = None
idle_size: int | None = None


@dataclass(frozen=True)
class PatchNodePoolResourcesRequest:
cpu: float
available_cpu: float
memory: int
available_memory: int
disk_size: int
available_disk_size: int

nvidia_gpu: int | None = None
nvidia_gpu_model: str | None = None
amd_gpu: int | None = None
amd_gpu_model: str | None = None
intel_gpu: int | None = None
intel_gpu_model: str | None = None

machine_type: str | None = None

min_size: int | None = None
max_size: int | None = None


@dataclass(frozen=True)
class StorageInstance:
name: str
Expand Down Expand Up @@ -183,7 +183,6 @@ class EFSThroughputMode(str, enum.Enum):

@dataclass(frozen=True)
class AWSStorage(Storage):
id: str
description: str
performance_mode: EFSPerformanceMode
throughput_mode: EFSThroughputMode
Expand Down Expand Up @@ -215,7 +214,6 @@ class GoogleFilestoreTier(str, enum.Enum):

@dataclass(frozen=True)
class GoogleStorage(Storage):
id: str
description: str
tier: GoogleFilestoreTier

Expand Down Expand Up @@ -254,7 +252,6 @@ class AzureReplicationType(str, enum.Enum):

@dataclass(frozen=True)
class AzureStorage(Storage):
id: str
description: str
tier: AzureStorageTier
replication_type: AzureReplicationType
Expand Down Expand Up @@ -518,9 +515,11 @@ class ResourcePoolType:

@dataclass(frozen=True)
class Resources:
cpu_m: int
cpu: float
memory: int
gpu: int = 0
nvidia_gpu: int = 0
amd_gpu: int = 0
intel_gpu: int = 0


@dataclass(frozen=True)
Expand Down Expand Up @@ -552,6 +551,22 @@ class OrchestratorConfig:
idle_jobs: Sequence[IdleJobConfig] = ()


@dataclass
class PatchOrchestratorConfigRequest:
    """Partial orchestrator configuration used inside a cluster patch.

    This is the type of ``PatchClusterRequest.orchestrator``. Every field is
    optional and defaults to ``None``.

    NOTE(review): unlike the other ``Patch*Request`` dataclasses this one is
    not declared ``frozen=True``, so instances are mutable and unhashable;
    confirm whether that is intentional before relying on either property.
    """

    # Hostname-related settings for jobs.
    job_hostname_template: str | None = None
    job_internal_hostname_template: str | None = None
    job_fallback_hostname: str | None = None

    # Scheduling timeouts (the ``_s`` suffix suggests seconds).
    job_schedule_timeout_s: float | None = None
    job_schedule_scale_up_timeout_s: float | None = None

    is_http_ingress_secure: bool | None = None

    # Collections default to None rather than an empty sequence.
    resource_pool_types: Sequence[ResourcePoolType] | None = None
    resource_presets: Sequence[ResourcePreset] | None = None

    allow_privileged_mode: bool | None = None
    allow_job_priority: bool | None = None
    pre_pull_images: Sequence[str] | None = None
    idle_jobs: Sequence[IdleJobConfig] | None = None


@dataclass
class ARecord:
name: str
Expand Down Expand Up @@ -618,3 +633,20 @@ class Cluster:
buckets: BucketsConfig | None = None
ingress: IngressConfig | None = None
energy: EnergyConfig | None = None


@dataclass(frozen=True)
class PatchClusterRequest:
    """Immutable request body for patching a cluster.

    Passed to the client's ``patch_cluster`` method, which converts it to a
    JSON payload through the payload factory. It replaces the long list of
    per-section keyword arguments that ``patch_cluster`` previously took.
    Every section is optional and defaults to ``None``.

    NOTE(review): ``None`` presumably means "do not touch this section";
    confirm against ``create_patch_cluster_request``, which is not visible
    here.
    """

    credentials: CredentialsConfig | None = None
    storage: StorageConfig | None = None
    registry: RegistryConfig | None = None
    # Takes the partial patch type, not the full OrchestratorConfig.
    orchestrator: PatchOrchestratorConfigRequest | None = None
    monitoring: MonitoringConfig | None = None
    secrets: SecretsConfig | None = None
    metrics: MetricsConfig | None = None
    disks: DisksConfig | None = None
    buckets: BucketsConfig | None = None
    ingress: IngressConfig | None = None
    dns: DNSConfig | None = None
    timezone: ZoneInfo | None = None
    energy: EnergyConfig | None = None
Loading
Loading