Skip to content

Commit

Permalink
add resource presets client methods
Browse files Browse the repository at this point in the history
  • Loading branch information
zubenkoivan committed Jul 14, 2022
1 parent a1cf101 commit db8a65c
Showing 1 changed file with 127 additions and 60 deletions.
187 changes: 127 additions & 60 deletions neuro_config_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
from collections.abc import Sequence
from dataclasses import dataclass
from types import TracebackType
from typing import Any

Expand All @@ -24,6 +25,7 @@
NotificationType,
OrchestratorConfig,
RegistryConfig,
ResourcePreset,
SecretsConfig,
StorageConfig,
)
Expand All @@ -32,6 +34,44 @@
logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class _Endpoints:
clusters: URL
cloud_providers: URL

def cluster(self, cluster_name: str) -> URL:
return self.clusters / cluster_name

def node_pools(self, cluster_name: str) -> URL:
return self.cluster(cluster_name) / "cloud_provider/node_pools"

def node_pool(self, cluster_name: str, node_pool_name: str) -> URL:
return self.node_pools(cluster_name) / node_pool_name

def storages(self, cluster_name: str) -> URL:
return self.cluster(cluster_name) / "cloud_provider/storages"

def storage(self, cluster_name: str, storage_name: str) -> URL:
return self.storages(cluster_name) / storage_name

def notifications(self, cluster_name: str) -> URL:
return self.cluster(cluster_name) / "notifications"

def resource_presets(self, cluster_name: str) -> URL:
return self.cluster(cluster_name) / "orchestrator/resource_presets"

def resource_preset(self, cluster_name: str, preset_name: str) -> URL:
return self.resource_presets(cluster_name) / preset_name

@classmethod
def create(cls, base_url: URL) -> _Endpoints:
clusters = base_url / "api/v1/clusters"
return cls(
clusters=clusters,
cloud_providers=base_url / "api/v1/cloud_providers",
)


class ConfigClient:
def __init__(
self,
Expand All @@ -40,8 +80,7 @@ def __init__(
timeout: aiohttp.ClientTimeout = aiohttp.client.DEFAULT_TIMEOUT,
trace_configs: Sequence[aiohttp.TraceConfig] = (),
):
self._clusters_url = url / "api/v1/clusters"
self._cloud_providers_url = url / "api/v1/cloud_providers"
self._endpoints = _Endpoints.create(url)
self._token = token
self._timeout = timeout
self._trace_configs = trace_configs
Expand Down Expand Up @@ -82,7 +121,9 @@ def _create_headers(self, *, token: str | None = None) -> dict[str, str]:
async def get_clusters(self, *, token: str | None = None) -> Sequence[Cluster]:
assert self._client
headers = self._create_headers(token=token)
async with self._client.get(self._clusters_url, headers=headers) as response:
async with self._client.get(
self._endpoints.clusters, headers=headers
) as response:
response.raise_for_status()
payload = await response.json()
return [self._entity_factory.create_cluster(p) for p in payload]
Expand All @@ -91,7 +132,7 @@ async def get_cluster(self, name: str, *, token: str | None = None) -> Cluster:
assert self._client
headers = self._create_headers(token=token)
async with self._client.get(
self._clusters_url / name, headers=headers
self._endpoints.cluster(name), headers=headers
) as response:
response.raise_for_status()
payload = await response.json()
Expand All @@ -110,7 +151,7 @@ async def create_blank_cluster(
payload = {"name": name, "token": service_token}
try:
async with self._client.post(
self._clusters_url, headers=headers, json=payload
self._endpoints.clusters, headers=headers, json=payload
) as resp:
resp.raise_for_status()
resp_payload = await resp.json()
Expand Down Expand Up @@ -139,7 +180,7 @@ async def patch_cluster(
dns: DNSConfig | None = None,
) -> Cluster:
assert self._client
url = self._clusters_url / name
url = self._endpoints.cluster(name)
headers = self._create_headers(token=token)
payload: dict[str, Any] = {}
if credentials:
Expand Down Expand Up @@ -177,7 +218,7 @@ async def delete_cluster(self, name: str, *, token: str | None = None) -> None:
assert self._client
headers = self._create_headers(token=token)
async with self._client.delete(
self._clusters_url / name, headers=headers
self._endpoints.cluster(name), headers=headers
) as resp:
resp.raise_for_status()

Expand All @@ -193,7 +234,7 @@ async def add_storage(
) -> Cluster:
assert self._client
try:
url = self._clusters_url / cluster_name / "cloud_provider/storages"
url = self._endpoints.storages(cluster_name)
headers = self._create_headers(token=token)
payload: dict[str, Any] = {"name": storage_name}
if size is not None:
Expand Down Expand Up @@ -223,18 +264,9 @@ async def patch_storage(
assert self._client
try:
if storage_name:
url = (
self._clusters_url
/ cluster_name
/ "cloud_provider/storages"
/ storage_name
)
url = self._endpoints.storage(cluster_name, storage_name)
else:
url = (
self._clusters_url
/ cluster_name
/ "cloud_provider/storages/default/entry"
)
url = self._endpoints.storage(cluster_name, "default/entry")
headers = self._create_headers(token=token)
payload: dict[str, Any] = {}
if ready is not None:
Expand All @@ -261,12 +293,7 @@ async def remove_storage(
) -> Cluster:
assert self._client
try:
url = (
self._clusters_url
/ cluster_name
/ "cloud_provider/storages"
/ storage_name
)
url = self._endpoints.storage(cluster_name, storage_name)
headers = self._create_headers(token=token)
async with self._client.delete(
url.with_query(start_deployment=str(start_deployment).lower()),
Expand All @@ -288,29 +315,20 @@ async def get_node_pool(
token: str | None = None,
) -> NodePool:
assert self._client
url = (
self._clusters_url
/ cluster_name
/ "cloud_provider"
/ "node_pools"
/ node_pool_name
)

url = self._endpoints.node_pool(cluster_name, node_pool_name)
headers = self._create_headers(token=token)
async with self._client.get(url=url, headers=headers) as response:
response.raise_for_status()
resp_payload = await response.json()
return self._entity_factory.create_node_pool(resp_payload)

async def get_node_pools(
self,
cluster_name: str,
*,
token: str | None = None,
self, cluster_name: str, *, token: str | None = None
) -> list[NodePool]:
assert self._client
url = self._clusters_url / cluster_name / "cloud_provider/node_pools"

url = self._endpoints.node_pools(cluster_name)
headers = self._create_headers(token=token)
async with self._client.get(url=url, headers=headers) as response:
response.raise_for_status()
Expand All @@ -324,10 +342,11 @@ async def get_node_pool_templates(
token: str | None = None,
) -> list[NodePoolTemplate]:
assert self._client

if cloud_provider_type == CloudProviderType.ON_PREM:
raise ValueError("Templates are not supported in onprem clusters.")

url = self._cloud_providers_url / cloud_provider_type.value
url = self._endpoints.cloud_providers / cloud_provider_type.value
headers = self._create_headers(token=token)
async with self._client.get(url=url, headers=headers) as response:
response.raise_for_status()
Expand Down Expand Up @@ -365,7 +384,7 @@ async def add_node_pool(
"""
assert self._client

url = self._clusters_url / cluster_name / "cloud_provider/node_pools"
url = self._endpoints.node_pools(cluster_name)
headers = self._create_headers(token=token)
payload = self._payload_factory.create_node_pool(node_pool)
async with self._client.post(
Expand All @@ -387,13 +406,7 @@ async def put_node_pool(
) -> Cluster:
assert self._client

url = (
self._clusters_url
/ cluster_name
/ "cloud_provider"
/ "node_pools"
/ node_pool.name
)
url = self._endpoints.node_pool(cluster_name, node_pool.name)
headers = self._create_headers(token=token)
payload = self._payload_factory.create_node_pool(node_pool)

Expand All @@ -416,13 +429,7 @@ async def patch_node_pool(
) -> Cluster:
assert self._client

url = (
self._clusters_url
/ cluster_name
/ "cloud_provider"
/ "node_pools"
/ node_pool_name
)
url = self._endpoints.node_pool(cluster_name, node_pool_name)
headers = self._create_headers(token=token)
payload: dict[str, Any] = {}
if idle_size is not None:
Expand All @@ -443,12 +450,7 @@ async def delete_node_pool(
) -> Cluster:
assert self._client

url = (
self._clusters_url
/ cluster_name
/ "cloud_provider/node_pools"
/ node_pool_name
)
url = self._endpoints.node_pool(cluster_name, node_pool_name)
headers = self._create_headers(token=token)
async with self._client.delete(
url.with_query(start_deployment=str(start_deployment).lower()),
Expand All @@ -467,10 +469,75 @@ async def notify(
token: str | None = None,
) -> None:
assert self._client
url = self._clusters_url / cluster_name / "notifications"

url = self._endpoints.notifications(cluster_name)
headers = self._create_headers(token=token)
payload = {"notification_type": notification_type.value}
if message:
payload["message"] = message
async with self._client.post(url, headers=headers, json=payload) as response:
response.raise_for_status()

async def get_resource_presets(
self, cluster_name: str, *, token: str | None = None
) -> list[ResourcePreset]:
assert self._client

url = self._endpoints.resource_presets(cluster_name)
headers = self._create_headers(token=token)
async with self._client.get(url, headers=headers) as response:
response.raise_for_status()
resp_payload = await response.json()
return [
self._entity_factory.create_resource_preset(p) for p in resp_payload
]

async def get_resource_preset(
self, cluster_name: str, preset_name: str, *, token: str | None = None
) -> ResourcePreset:
assert self._client

url = self._endpoints.resource_preset(cluster_name, preset_name)
headers = self._create_headers(token=token)
async with self._client.get(url, headers=headers) as response:
response.raise_for_status()
resp_payload = await response.json()
return self._entity_factory.create_resource_preset(resp_payload)

async def add_resource_preset(
self, cluster_name: str, preset: ResourcePreset, *, token: str | None = None
) -> Cluster:
assert self._client

url = self._endpoints.resource_presets(cluster_name)
headers = self._create_headers(token=token)
payload = self._payload_factory.create_resource_preset(preset)
async with self._client.post(url, headers=headers, json=payload) as response:
response.raise_for_status()
resp_payload = await response.json()
return self._entity_factory.create_cluster(resp_payload)

async def put_resource_preset(
self, cluster_name: str, preset: ResourcePreset, *, token: str | None = None
) -> Cluster:
assert self._client

url = self._endpoints.resource_preset(cluster_name, preset.name)
headers = self._create_headers(token=token)
payload = self._payload_factory.create_resource_preset(preset)
async with self._client.put(url, headers=headers, json=payload) as response:
response.raise_for_status()
resp_payload = await response.json()
return self._entity_factory.create_cluster(resp_payload)

async def delete_resource_preset(
self, cluster_name: str, preset_name: str, *, token: str | None = None
) -> Cluster:
assert self._client

url = self._endpoints.resource_preset(cluster_name, preset_name)
headers = self._create_headers(token=token)
async with self._client.delete(url, headers=headers) as response:
response.raise_for_status()
resp_payload = await response.json()
return self._entity_factory.create_cluster(resp_payload)

0 comments on commit db8a65c

Please sign in to comment.