Skip to content

Commit

Permalink
ENG-504 update cluster, node pool patch methods (#139)
Browse files Browse the repository at this point in the history
  • Loading branch information
zubenkoivan authored Dec 12, 2024
1 parent a064c00 commit 3dbdb01
Show file tree
Hide file tree
Showing 5 changed files with 630 additions and 344 deletions.
20 changes: 12 additions & 8 deletions neuro_config_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,15 @@
from .client import ConfigClient, ConfigClientBase
from .entities import (
ACMEEnvironment,
AddNodePoolRequest,
ARecord,
AWSCloudProvider,
AWSCredentials,
AWSStorage,
AWSStorageOptions,
AzureCloudProvider,
AzureCredentials,
AzureReplicationType,
AzureStorage,
AzureStorageOptions,
AzureStorageTier,
BucketsConfig,
CloudProvider,
Expand All @@ -36,7 +35,6 @@
GoogleCloudProvider,
GoogleFilestoreTier,
GoogleStorage,
GoogleStorageOptions,
GrafanaCredentials,
HelmRegistryConfig,
IdleJobConfig,
Expand All @@ -52,6 +50,11 @@
OnPremCloudProvider,
OpenStackCredentials,
OrchestratorConfig,
PatchClusterRequest,
PatchNodePoolResourcesRequest,
PatchNodePoolSizeRequest,
PatchOrchestratorConfigRequest,
PutNodePoolRequest,
RegistryConfig,
ResourcePoolType,
ResourcePreset,
Expand All @@ -61,7 +64,6 @@
Storage,
StorageConfig,
StorageInstance,
StorageOptions,
TPUPreset,
TPUResource,
VCDCloudProvider,
Expand All @@ -75,16 +77,15 @@
"ConfigClient",
"ConfigClientBase",
"ACMEEnvironment",
"AddNodePoolRequest",
"ARecord",
"AWSCloudProvider",
"AWSCredentials",
"AWSStorage",
"AWSStorageOptions",
"AzureCloudProvider",
"AzureCredentials",
"AzureReplicationType",
"AzureStorage",
"AzureStorageOptions",
"AzureStorageTier",
"BucketsConfig",
"CloudProvider",
Expand All @@ -106,7 +107,6 @@
"GoogleCloudProvider",
"GoogleFilestoreTier",
"GoogleStorage",
"GoogleStorageOptions",
"GrafanaCredentials",
"HelmRegistryConfig",
"IdleJobConfig",
Expand All @@ -122,6 +122,11 @@
"OnPremCloudProvider",
"OpenStackCredentials",
"OrchestratorConfig",
"PatchClusterRequest",
"PatchNodePoolResourcesRequest",
"PatchNodePoolSizeRequest",
"PatchOrchestratorConfigRequest",
"PutNodePoolRequest",
"RegistryConfig",
"ResourcePoolType",
"ResourcePreset",
Expand All @@ -131,7 +136,6 @@
"Storage",
"StorageConfig",
"StorageInstance",
"StorageOptions",
"TPUPreset",
"TPUResource",
"VCDCloudProvider",
Expand Down
100 changes: 22 additions & 78 deletions neuro_config_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import abc
import logging
import sys
from collections.abc import AsyncIterator, Mapping, Sequence
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from dataclasses import dataclass
Expand All @@ -14,33 +13,20 @@
from yarl import URL

from .entities import (
BucketsConfig,
AddNodePoolRequest,
CloudProviderOptions,
CloudProviderType,
Cluster,
CredentialsConfig,
DisksConfig,
DNSConfig,
EnergyConfig,
IngressConfig,
MetricsConfig,
MonitoringConfig,
NodePool,
NotificationType,
OrchestratorConfig,
RegistryConfig,
PatchClusterRequest,
PatchNodePoolResourcesRequest,
PatchNodePoolSizeRequest,
PutNodePoolRequest,
ResourcePreset,
SecretsConfig,
StorageConfig,
)
from .factories import EntityFactory, PayloadFactory

if sys.version_info >= (3, 9):
from zoneinfo import ZoneInfo
else:
# why not backports.zoneinfo: https://github.com/pganssle/zoneinfo/issues/125
from backports.zoneinfo._zoneinfo import ZoneInfo

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -170,56 +156,10 @@ async def create_blank_cluster(
return await self.get_cluster(name)

async def patch_cluster(
self,
name: str,
*,
credentials: CredentialsConfig | None = None,
storage: StorageConfig | None = None,
registry: RegistryConfig | None = None,
orchestrator: OrchestratorConfig | None = None,
monitoring: MonitoringConfig | None = None,
secrets: SecretsConfig | None = None,
metrics: MetricsConfig | None = None,
disks: DisksConfig | None = None,
buckets: BucketsConfig | None = None,
ingress: IngressConfig | None = None,
dns: DNSConfig | None = None,
timezone: ZoneInfo | None = None,
energy: EnergyConfig | None = None,
token: str | None = None,
self, name: str, request: PatchClusterRequest, *, token: str | None = None
) -> Cluster:
path = self._endpoints.cluster(name)
payload: dict[str, Any] = {}
if credentials:
payload["credentials"] = self._payload_factory.create_credentials(
credentials
)
if storage:
payload["storage"] = self._payload_factory.create_storage(storage)
if registry:
payload["registry"] = self._payload_factory.create_registry(registry)
if orchestrator:
payload["orchestrator"] = self._payload_factory.create_orchestrator(
orchestrator
)
if monitoring:
payload["monitoring"] = self._payload_factory.create_monitoring(monitoring)
if secrets:
payload["secrets"] = self._payload_factory.create_secrets(secrets)
if metrics:
payload["metrics"] = self._payload_factory.create_metrics(metrics)
if disks:
payload["disks"] = self._payload_factory.create_disks(disks)
if buckets:
payload["buckets"] = self._payload_factory.create_buckets(buckets)
if ingress:
payload["ingress"] = self._payload_factory.create_ingress(ingress)
if dns:
payload["dns"] = self._payload_factory.create_dns(dns)
if timezone:
payload["timezone"] = str(timezone)
if energy:
payload["energy"] = self._payload_factory.create_energy(energy)
payload = self._payload_factory.create_patch_cluster_request(request)
async with self._request(
"PATCH", path, headers=self._create_headers(token=token), json=payload
) as resp:
Expand Down Expand Up @@ -337,16 +277,17 @@ async def list_node_pools(
async def add_node_pool(
self,
cluster_name: str,
node_pool: NodePool,
node_pool: AddNodePoolRequest,
*,
start_deployment: bool = True,
token: str | None = None,
) -> Cluster:
"""Add new node pool to the existing cluster.
Cloud provider should be already set up.
Make sure you use one of the available node pool templates by providing its ID,
if the cluster is deployed in public cloud (AWS / GCP / Azure / VCD).
Make sure you use one of the available node pool templates by providing
its machine type, if the cluster is deployed in public cloud
(AWS / GCP / Azure / VCD).
Args:
cluster_name (str): Name of the cluster within the platform.
Expand All @@ -358,7 +299,7 @@ async def add_node_pool(
Cluster: Cluster instance with applied changes
"""
path = self._endpoints.node_pools(cluster_name)
payload = self._payload_factory.create_node_pool(node_pool)
payload = self._payload_factory.create_add_node_pool_request(node_pool)
async with self._request(
"POST",
path,
Expand All @@ -372,13 +313,13 @@ async def add_node_pool(
async def put_node_pool(
self,
cluster_name: str,
node_pool: NodePool,
node_pool: PutNodePoolRequest,
*,
start_deployment: bool = True,
token: str | None = None,
) -> Cluster:
path = self._endpoints.node_pool(cluster_name, node_pool.name)
payload = self._payload_factory.create_node_pool(node_pool)
payload = self._payload_factory.create_add_node_pool_request(node_pool)
async with self._request(
"PUT",
path,
Expand All @@ -393,16 +334,19 @@ async def patch_node_pool(
self,
cluster_name: str,
node_pool_name: str,
request: PatchNodePoolSizeRequest | PatchNodePoolResourcesRequest,
*,
idle_size: int | None = None,
start_deployment: bool = True,
token: str | None = None,
) -> Cluster:
path = self._endpoints.node_pool(cluster_name, node_pool_name)
payload: dict[str, Any] = {}
if idle_size is not None:
payload["idle_size"] = idle_size
payload = self._payload_factory.create_patch_node_pool_request(request)
async with self._request(
"PATCH", path, headers=self._create_headers(token=token), json=payload
"PATCH",
path,
params={"start_deployment": str(start_deployment).lower()},
headers=self._create_headers(token=token),
json=payload,
) as response:
resp_payload = await response.json()
return self._entity_factory.create_cluster(resp_payload)
Expand Down
Loading

0 comments on commit 3dbdb01

Please sign in to comment.