Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add resource presets client methods #90

Merged
merged 1 commit into from
Jul 14, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 127 additions & 60 deletions neuro_config_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
from collections.abc import Sequence
from dataclasses import dataclass
from types import TracebackType
from typing import Any

Expand All @@ -24,6 +25,7 @@
NotificationType,
OrchestratorConfig,
RegistryConfig,
ResourcePreset,
SecretsConfig,
StorageConfig,
)
Expand All @@ -32,6 +34,44 @@
logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class _Endpoints:
clusters: URL
cloud_providers: URL

def cluster(self, cluster_name: str) -> URL:
return self.clusters / cluster_name

def node_pools(self, cluster_name: str) -> URL:
return self.cluster(cluster_name) / "cloud_provider/node_pools"

def node_pool(self, cluster_name: str, node_pool_name: str) -> URL:
return self.node_pools(cluster_name) / node_pool_name

def storages(self, cluster_name: str) -> URL:
return self.cluster(cluster_name) / "cloud_provider/storages"

def storage(self, cluster_name: str, storage_name: str) -> URL:
return self.storages(cluster_name) / storage_name

def notifications(self, cluster_name: str) -> URL:
return self.cluster(cluster_name) / "notifications"

def resource_presets(self, cluster_name: str) -> URL:
return self.cluster(cluster_name) / "orchestrator/resource_presets"

def resource_preset(self, cluster_name: str, preset_name: str) -> URL:
return self.resource_presets(cluster_name) / preset_name

@classmethod
def create(cls, base_url: URL) -> _Endpoints:
clusters = base_url / "api/v1/clusters"
return cls(
clusters=clusters,
cloud_providers=base_url / "api/v1/cloud_providers",
)


class ConfigClient:
def __init__(
self,
Expand All @@ -40,8 +80,7 @@ def __init__(
timeout: aiohttp.ClientTimeout = aiohttp.client.DEFAULT_TIMEOUT,
trace_configs: Sequence[aiohttp.TraceConfig] = (),
):
self._clusters_url = url / "api/v1/clusters"
self._cloud_providers_url = url / "api/v1/cloud_providers"
self._endpoints = _Endpoints.create(url)
self._token = token
self._timeout = timeout
self._trace_configs = trace_configs
Expand Down Expand Up @@ -82,7 +121,9 @@ def _create_headers(self, *, token: str | None = None) -> dict[str, str]:
async def get_clusters(self, *, token: str | None = None) -> Sequence[Cluster]:
assert self._client
headers = self._create_headers(token=token)
async with self._client.get(self._clusters_url, headers=headers) as response:
async with self._client.get(
self._endpoints.clusters, headers=headers
) as response:
response.raise_for_status()
payload = await response.json()
return [self._entity_factory.create_cluster(p) for p in payload]
Expand All @@ -91,7 +132,7 @@ async def get_cluster(self, name: str, *, token: str | None = None) -> Cluster:
assert self._client
headers = self._create_headers(token=token)
async with self._client.get(
self._clusters_url / name, headers=headers
self._endpoints.cluster(name), headers=headers
) as response:
response.raise_for_status()
payload = await response.json()
Expand All @@ -110,7 +151,7 @@ async def create_blank_cluster(
payload = {"name": name, "token": service_token}
try:
async with self._client.post(
self._clusters_url, headers=headers, json=payload
self._endpoints.clusters, headers=headers, json=payload
) as resp:
resp.raise_for_status()
resp_payload = await resp.json()
Expand Down Expand Up @@ -139,7 +180,7 @@ async def patch_cluster(
dns: DNSConfig | None = None,
) -> Cluster:
assert self._client
url = self._clusters_url / name
url = self._endpoints.cluster(name)
headers = self._create_headers(token=token)
payload: dict[str, Any] = {}
if credentials:
Expand Down Expand Up @@ -177,7 +218,7 @@ async def delete_cluster(self, name: str, *, token: str | None = None) -> None:
assert self._client
headers = self._create_headers(token=token)
async with self._client.delete(
self._clusters_url / name, headers=headers
self._endpoints.cluster(name), headers=headers
) as resp:
resp.raise_for_status()

Expand All @@ -193,7 +234,7 @@ async def add_storage(
) -> Cluster:
assert self._client
try:
url = self._clusters_url / cluster_name / "cloud_provider/storages"
url = self._endpoints.storages(cluster_name)
headers = self._create_headers(token=token)
payload: dict[str, Any] = {"name": storage_name}
if size is not None:
Expand Down Expand Up @@ -223,18 +264,9 @@ async def patch_storage(
assert self._client
try:
if storage_name:
url = (
self._clusters_url
/ cluster_name
/ "cloud_provider/storages"
/ storage_name
)
url = self._endpoints.storage(cluster_name, storage_name)
else:
url = (
self._clusters_url
/ cluster_name
/ "cloud_provider/storages/default/entry"
)
url = self._endpoints.storage(cluster_name, "default/entry")
headers = self._create_headers(token=token)
payload: dict[str, Any] = {}
if ready is not None:
Expand All @@ -261,12 +293,7 @@ async def remove_storage(
) -> Cluster:
assert self._client
try:
url = (
self._clusters_url
/ cluster_name
/ "cloud_provider/storages"
/ storage_name
)
url = self._endpoints.storage(cluster_name, storage_name)
headers = self._create_headers(token=token)
async with self._client.delete(
url.with_query(start_deployment=str(start_deployment).lower()),
Expand All @@ -288,29 +315,20 @@ async def get_node_pool(
token: str | None = None,
) -> NodePool:
assert self._client
url = (
self._clusters_url
/ cluster_name
/ "cloud_provider"
/ "node_pools"
/ node_pool_name
)

url = self._endpoints.node_pool(cluster_name, node_pool_name)
headers = self._create_headers(token=token)
async with self._client.get(url=url, headers=headers) as response:
response.raise_for_status()
resp_payload = await response.json()
return self._entity_factory.create_node_pool(resp_payload)

async def get_node_pools(
self,
cluster_name: str,
*,
token: str | None = None,
self, cluster_name: str, *, token: str | None = None
) -> list[NodePool]:
assert self._client
url = self._clusters_url / cluster_name / "cloud_provider/node_pools"

url = self._endpoints.node_pools(cluster_name)
headers = self._create_headers(token=token)
async with self._client.get(url=url, headers=headers) as response:
response.raise_for_status()
Expand All @@ -324,10 +342,11 @@ async def get_node_pool_templates(
token: str | None = None,
) -> list[NodePoolTemplate]:
assert self._client

if cloud_provider_type == CloudProviderType.ON_PREM:
raise ValueError("Templates are not supported in onprem clusters.")

url = self._cloud_providers_url / cloud_provider_type.value
url = self._endpoints.cloud_providers / cloud_provider_type.value
headers = self._create_headers(token=token)
async with self._client.get(url=url, headers=headers) as response:
response.raise_for_status()
Expand Down Expand Up @@ -365,7 +384,7 @@ async def add_node_pool(
"""
assert self._client

url = self._clusters_url / cluster_name / "cloud_provider/node_pools"
url = self._endpoints.node_pools(cluster_name)
headers = self._create_headers(token=token)
payload = self._payload_factory.create_node_pool(node_pool)
async with self._client.post(
Expand All @@ -387,13 +406,7 @@ async def put_node_pool(
) -> Cluster:
assert self._client

url = (
self._clusters_url
/ cluster_name
/ "cloud_provider"
/ "node_pools"
/ node_pool.name
)
url = self._endpoints.node_pool(cluster_name, node_pool.name)
headers = self._create_headers(token=token)
payload = self._payload_factory.create_node_pool(node_pool)

Expand All @@ -416,13 +429,7 @@ async def patch_node_pool(
) -> Cluster:
assert self._client

url = (
self._clusters_url
/ cluster_name
/ "cloud_provider"
/ "node_pools"
/ node_pool_name
)
url = self._endpoints.node_pool(cluster_name, node_pool_name)
headers = self._create_headers(token=token)
payload: dict[str, Any] = {}
if idle_size is not None:
Expand All @@ -443,12 +450,7 @@ async def delete_node_pool(
) -> Cluster:
assert self._client

url = (
self._clusters_url
/ cluster_name
/ "cloud_provider/node_pools"
/ node_pool_name
)
url = self._endpoints.node_pool(cluster_name, node_pool_name)
headers = self._create_headers(token=token)
async with self._client.delete(
url.with_query(start_deployment=str(start_deployment).lower()),
Expand All @@ -467,10 +469,75 @@ async def notify(
token: str | None = None,
) -> None:
assert self._client
url = self._clusters_url / cluster_name / "notifications"

url = self._endpoints.notifications(cluster_name)
headers = self._create_headers(token=token)
payload = {"notification_type": notification_type.value}
if message:
payload["message"] = message
async with self._client.post(url, headers=headers, json=payload) as response:
response.raise_for_status()

async def get_resource_presets(
self, cluster_name: str, *, token: str | None = None
) -> list[ResourcePreset]:
assert self._client

url = self._endpoints.resource_presets(cluster_name)
headers = self._create_headers(token=token)
async with self._client.get(url, headers=headers) as response:
response.raise_for_status()
resp_payload = await response.json()
return [
self._entity_factory.create_resource_preset(p) for p in resp_payload
]

async def get_resource_preset(
self, cluster_name: str, preset_name: str, *, token: str | None = None
) -> ResourcePreset:
assert self._client

url = self._endpoints.resource_preset(cluster_name, preset_name)
headers = self._create_headers(token=token)
async with self._client.get(url, headers=headers) as response:
response.raise_for_status()
resp_payload = await response.json()
return self._entity_factory.create_resource_preset(resp_payload)

async def add_resource_preset(
self, cluster_name: str, preset: ResourcePreset, *, token: str | None = None
) -> Cluster:
assert self._client

url = self._endpoints.resource_presets(cluster_name)
headers = self._create_headers(token=token)
payload = self._payload_factory.create_resource_preset(preset)
async with self._client.post(url, headers=headers, json=payload) as response:
response.raise_for_status()
resp_payload = await response.json()
return self._entity_factory.create_cluster(resp_payload)

async def put_resource_preset(
self, cluster_name: str, preset: ResourcePreset, *, token: str | None = None
) -> Cluster:
assert self._client

url = self._endpoints.resource_preset(cluster_name, preset.name)
headers = self._create_headers(token=token)
payload = self._payload_factory.create_resource_preset(preset)
async with self._client.put(url, headers=headers, json=payload) as response:
response.raise_for_status()
resp_payload = await response.json()
return self._entity_factory.create_cluster(resp_payload)

async def delete_resource_preset(
self, cluster_name: str, preset_name: str, *, token: str | None = None
) -> Cluster:
assert self._client

url = self._endpoints.resource_preset(cluster_name, preset_name)
headers = self._create_headers(token=token)
async with self._client.delete(url, headers=headers) as response:
response.raise_for_status()
resp_payload = await response.json()
return self._entity_factory.create_cluster(resp_payload)