From 2150055b5ef074fa6c4119d716c1068c4747554e Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 22 Dec 2021 10:57:24 +0200 Subject: [PATCH] Upgrade to Python 3.9 (#1844) * Upgrade to Python 3.9 * fix --- .github/workflows/ci.yaml | 6 +- .pre-commit-config.yaml | 6 +- Dockerfile | 7 +- Makefile | 4 +- platform_api/api.py | 19 +- platform_api/cluster.py | 7 +- platform_api/cluster_config.py | 3 +- platform_api/cluster_config_factory.py | 19 +- platform_api/config.py | 3 +- platform_api/config_client.py | 3 +- platform_api/config_factory.py | 7 +- platform_api/handlers/job_request_builder.py | 14 +- platform_api/handlers/jobs_handler.py | 61 ++-- platform_api/handlers/validators.py | 9 +- platform_api/kube_cluster.py | 5 +- platform_api/log.py | 3 +- platform_api/orchestrator/base.py | 7 +- .../orchestrator/base_postgres_storage.py | 9 +- .../orchestrator/billing_log/service.py | 9 +- .../orchestrator/billing_log/storage.py | 25 +- platform_api/orchestrator/job.py | 21 +- .../orchestrator/job_policy_enforcer.py | 24 +- platform_api/orchestrator/job_request.py | 73 +++-- platform_api/orchestrator/jobs_poller.py | 15 +- platform_api/orchestrator/jobs_service.py | 15 +- .../orchestrator/jobs_storage/base.py | 55 ++-- .../orchestrator/jobs_storage/in_memory.py | 15 +- .../orchestrator/jobs_storage/postgres.py | 34 +- platform_api/orchestrator/kube_client.py | 245 +++++++------- .../orchestrator/kube_orchestrator.py | 43 +-- platform_api/orchestrator/poller_service.py | 31 +- platform_api/poller_main.py | 7 +- platform_api/resource.py | 3 +- platform_api/utils/asyncio.py | 21 +- platform_api/utils/retry.py | 7 +- platform_api/utils/stream.py | 5 +- platform_api/utils/update_notifier.py | 10 +- setup.cfg | 3 +- tests/conftest.py | 4 +- tests/integration/admin.py | 5 +- tests/integration/api.py | 43 +-- tests/integration/auth.py | 26 +- tests/integration/conftest.py | 39 +-- tests/integration/diskapi.py | 11 +- tests/integration/docker.py | 3 +- tests/integration/notifications.py | 5 +- tests/integration/postgres.py | 2 +- tests/integration/secrets.py | 11 +- tests/integration/test_api.py | 309 +++++++++--------- tests/integration/test_config_client.py | 13 +- tests/integration/test_jobs_storage.py | 18 +- tests/integration/test_kube_orchestrator.py | 43 +-- tests/integration/test_notifications.py | 23 +- tests/unit/conftest.py | 78 ++--- tests/unit/test_billing_log_service.py | 3 +- tests/unit/test_cluster_config_factory.py | 19 +- tests/unit/test_config.py | 3 +- tests/unit/test_job.py | 63 ++-- tests/unit/test_job_policy_enforcer.py | 40 +-- tests/unit/test_job_rest_validator.py | 9 +- tests/unit/test_job_service.py | 3 +- tests/unit/test_jobs_poller.py | 3 +- tests/unit/test_jobs_poller_client.py | 3 +- tests/unit/test_kube_orchestrator.py | 30 +- tests/unit/test_models.py | 39 +-- 65 files changed, 811 insertions(+), 890 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 04a4cd8c3..fd25cd7c1 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -30,12 +30,12 @@ jobs: - name: Install python uses: actions/setup-python@v2 with: - python-version: 3.8.10 + python-version: 3.9.9 - name: Cache packages uses: actions/cache@v2.1.7 with: path: ~/.cache/pip - key: ${{ runner.os }}-py-3.8.10-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('setup.cfg') }} + key: ${{ runner.os }}-py-3.9.9-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('setup.cfg') }} - name: Install dependencies run: make setup - name: Lint @@ -113,7 +113,7 @@ jobs: - name: 
Install python uses: actions/setup-python@v2 with: - python-version: 3.8.10 + python-version: 3.9.9 - name: Install Helm uses: azure/setup-helm@v1 with: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d916c54ca..40b4653a4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,11 +38,15 @@ repos: files: | docs/spelling_wordlist.txt| .gitignore +- repo: https://github.com/sondrelg/pep585-upgrade + rev: 'v1.0.1' + hooks: + - id: upgrade-type-hints - repo: https://github.com/asottile/pyupgrade rev: 'v2.29.1' hooks: - id: pyupgrade - args: ['--py36-plus'] + args: ['--py39-plus'] - repo: https://gitlab.com/pycqa/flake8 rev: '3.9.2' hooks: diff --git a/Dockerfile b/Dockerfile index c5824c502..ca8ccf55b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,7 +1,4 @@ -ARG PYTHON_VERSION=3.8.12 -ARG PYTHON_BASE=buster - -FROM python:${PYTHON_VERSION} AS installer +FROM python:3.9.9-slim-bullseye AS installer ENV PATH=/root/.local/bin:$PATH @@ -12,7 +9,7 @@ COPY dist /tmp/dist RUN ls /tmp/dist RUN pip install --user --find-links /tmp/dist platform-api -FROM python:${PYTHON_VERSION}-${PYTHON_BASE} AS service +FROM python:3.9.9-slim-bullseye AS service LABEL org.opencontainers.image.source = "https://github.com/neuro-inc/platform-api" diff --git a/Makefile b/Makefile index 74635c78b..9c9372e8a 100644 --- a/Makefile +++ b/Makefile @@ -52,9 +52,7 @@ docker_build: rm -rf build dist pip install -U build python -m build - docker build \ - --build-arg PYTHON_BASE=slim-buster \ - -t $(IMAGE_NAME):latest . + docker build -t $(IMAGE_NAME):latest . docker_push: docker tag $(IMAGE_NAME):latest $(IMAGE_REPO):$(IMAGE_TAG) diff --git a/platform_api/api.py b/platform_api/api.py index 07b897ad2..bc7648bd9 100644 --- a/platform_api/api.py +++ b/platform_api/api.py @@ -1,11 +1,12 @@ import asyncio import logging +from collections.abc import AsyncIterator, Awaitable, Callable, Sequence from contextlib import AsyncExitStack -from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Sequence +from importlib.metadata import version +from typing import Any import aiohttp.web import aiohttp_cors -import pkg_resources from aiohttp.web import HTTPUnauthorized from aiohttp.web_urldispatcher import AbstractRoute from aiohttp_security import check_permission @@ -62,7 +63,7 @@ class ApiHandler: - def register(self, app: aiohttp.web.Application) -> List[AbstractRoute]: + def register(self, app: aiohttp.web.Application) -> list[AbstractRoute]: return app.add_routes((aiohttp.web.get("/ping", self.handle_ping),)) @notrace @@ -104,7 +105,7 @@ async def handle_clusters_sync( return aiohttp.web.Response(text="OK") async def handle_config(self, request: aiohttp.web.Request) -> aiohttp.web.Response: - data: Dict[str, Any] = {} + data: dict[str, Any] = {} try: user = await authorized_user(request) @@ -147,7 +148,7 @@ async def handle_config(self, request: aiohttp.web.Request) -> aiohttp.web.Respo def _convert_cluster_config_to_payload( self, user_cluster_config: UserClusterConfig - ) -> Dict[str, Any]: + ) -> dict[str, Any]: cluster_config = user_cluster_config.config orgs = user_cluster_config.orgs presets = [ @@ -169,8 +170,8 @@ def _convert_cluster_config_to_payload( "orgs": orgs, } - def _convert_preset_to_payload(self, preset: Preset) -> Dict[str, Any]: - payload: Dict[str, Any] = { + def _convert_preset_to_payload(self, preset: Preset) -> dict[str, Any]: + payload: dict[str, Any] = { "name": preset.name, "credits_per_hour": str(preset.credits_per_hour), "cpu": 
preset.cpu, @@ -243,7 +244,7 @@ async def create_jobs_app(config: Config) -> aiohttp.web.Application: return jobs_app -package_version = pkg_resources.get_distribution("platform-api").version +package_version = version(__package__) async def add_version_to_header( @@ -252,7 +253,7 @@ async def add_version_to_header( response.headers["X-Service-Version"] = f"platform-api/{package_version}" -def make_tracing_trace_configs(config: Config) -> List[aiohttp.TraceConfig]: +def make_tracing_trace_configs(config: Config) -> list[aiohttp.TraceConfig]: trace_configs = [] if config.zipkin: diff --git a/platform_api/cluster.py b/platform_api/cluster.py index 0c3bf444c..79c634fec 100644 --- a/platform_api/cluster.py +++ b/platform_api/cluster.py @@ -1,8 +1,9 @@ import asyncio import logging from abc import ABC, abstractmethod +from collections.abc import AsyncIterator, Sequence from contextlib import asynccontextmanager -from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Sequence +from typing import Any, Callable, Optional from aiorwlock import RWLock @@ -216,10 +217,10 @@ class ClusterConfigRegistry: def __init__( self, ) -> None: - self._records: Dict[str, ClusterConfig] = {} + self._records: dict[str, ClusterConfig] = {} @property - def cluster_names(self) -> List[str]: + def cluster_names(self) -> list[str]: return list(self._records) def get(self, name: str) -> ClusterConfig: diff --git a/platform_api/cluster_config.py b/platform_api/cluster_config.py index 72ca6d150..ec15dd816 100644 --- a/platform_api/cluster_config.py +++ b/platform_api/cluster_config.py @@ -1,7 +1,8 @@ +from collections.abc import Sequence from dataclasses import dataclass, field from enum import Enum from pathlib import PurePath -from typing import Optional, Sequence +from typing import Optional from yarl import URL diff --git a/platform_api/cluster_config_factory.py b/platform_api/cluster_config_factory.py index e027d272f..179a9b767 100644 --- a/platform_api/cluster_config_factory.py +++ b/platform_api/cluster_config_factory.py @@ -1,6 +1,7 @@ import logging +from collections.abc import Sequence from decimal import Decimal -from typing import Any, Dict, List, Optional, Sequence +from typing import Any, Optional import trafaret as t from yarl import URL @@ -14,12 +15,12 @@ class ClusterConfigFactory: def create_cluster_configs( - self, payload: Sequence[Dict[str, Any]] + self, payload: Sequence[dict[str, Any]] ) -> Sequence[ClusterConfig]: configs = (self.create_cluster_config(p) for p in payload) return [c for c in configs if c] - def create_cluster_config(self, payload: Dict[str, Any]) -> Optional[ClusterConfig]: + def create_cluster_config(self, payload: dict[str, Any]) -> Optional[ClusterConfig]: try: _cluster_config_validator.check(payload) return ClusterConfig( @@ -31,7 +32,7 @@ def create_cluster_config(self, payload: Dict[str, Any]) -> Optional[ClusterConf logging.warning(f"failed to parse cluster config: {err}") return None - def _create_ingress_config(self, payload: Dict[str, Any]) -> IngressConfig: + def _create_ingress_config(self, payload: dict[str, Any]) -> IngressConfig: return IngressConfig( registry_url=URL(payload["registry"]["url"]), storage_url=URL(payload["storage"]["url"]), @@ -43,7 +44,7 @@ def _create_ingress_config(self, payload: Dict[str, Any]) -> IngressConfig: buckets_url=URL(payload["buckets"]["url"]), ) - def _create_presets(self, payload: Dict[str, Any]) -> List[Preset]: + def _create_presets(self, payload: dict[str, Any]) -> list[Preset]: result = [] for preset in 
payload.get("resource_presets", []): result.append( @@ -64,7 +65,7 @@ def _create_presets(self, payload: Dict[str, Any]) -> List[Preset]: return result def _create_orchestrator_config( - self, payload: Dict[str, Any] + self, payload: dict[str, Any] ) -> OrchestratorConfig: orchestrator = payload["orchestrator"] presets = self._create_presets(orchestrator) @@ -93,7 +94,7 @@ def _create_orchestrator_config( ) def _create_tpu_preset( - self, payload: Optional[Dict[str, Any]] + self, payload: Optional[dict[str, Any]] ) -> Optional[TPUPreset]: if not payload: return None @@ -102,7 +103,7 @@ def _create_tpu_preset( type=payload["type"], software_version=payload["software_version"] ) - def _create_resource_pool_type(self, payload: Dict[str, Any]) -> ResourcePoolType: + def _create_resource_pool_type(self, payload: dict[str, Any]) -> ResourcePoolType: cpu = payload.get("cpu") memory_mb = payload.get("memory_mb") return ResourcePoolType( @@ -121,7 +122,7 @@ def _create_resource_pool_type(self, payload: Dict[str, Any]) -> ResourcePoolTyp ) def _create_tpu_resource( - self, payload: Optional[Dict[str, Any]] + self, payload: Optional[dict[str, Any]] ) -> Optional[TPUResource]: if not payload: return None diff --git a/platform_api/config.py b/platform_api/config.py index 798ec988c..b82d7c80f 100644 --- a/platform_api/config.py +++ b/platform_api/config.py @@ -1,7 +1,8 @@ +from collections.abc import Sequence from dataclasses import dataclass, field from datetime import timedelta from decimal import Decimal -from typing import Optional, Sequence +from typing import Optional from alembic.config import Config as AlembicConfig from yarl import URL diff --git a/platform_api/config_client.py b/platform_api/config_client.py index 538224269..05279f0aa 100644 --- a/platform_api/config_client.py +++ b/platform_api/config_client.py @@ -1,5 +1,6 @@ +from collections.abc import AsyncIterator, Sequence from contextlib import asynccontextmanager -from typing import Any, AsyncIterator, Optional, Sequence +from typing import Any, Optional import aiohttp from multidict import CIMultiDict diff --git a/platform_api/config_factory.py b/platform_api/config_factory.py index 411daef83..e7f16fbd7 100644 --- a/platform_api/config_factory.py +++ b/platform_api/config_factory.py @@ -1,8 +1,9 @@ import os import pathlib +from collections.abc import Sequence from decimal import Decimal from pathlib import PurePath -from typing import Dict, List, Optional, Sequence +from typing import Optional from alembic.config import Config as AlembicConfig from yarl import URL @@ -29,7 +30,7 @@ class EnvironConfigFactory: - def __init__(self, environ: Optional[Dict[str, str]] = None): + def __init__(self, environ: Optional[dict[str, str]] = None): self._environ = environ or os.environ def _get_bool(self, name: str, default: bool = False) -> bool: @@ -331,7 +332,7 @@ def create_registry(self) -> RegistryConfig: ) def create_storages(self) -> Sequence[StorageConfig]: - result: List[StorageConfig] = [] + result: list[StorageConfig] = [] i = 0 while True: diff --git a/platform_api/handlers/job_request_builder.py b/platform_api/handlers/job_request_builder.py index 63533ee5c..b4a75903d 100644 --- a/platform_api/handlers/job_request_builder.py +++ b/platform_api/handlers/job_request_builder.py @@ -1,5 +1,5 @@ from pathlib import PurePath -from typing import Any, Dict +from typing import Any from platform_api.orchestrator.job_request import ( Container, @@ -13,7 +13,7 @@ ) -def create_container_from_payload(payload: Dict[str, Any]) -> Container: +def 
create_container_from_payload(payload: dict[str, Any]) -> Container: if "container" in payload: # Deprecated. Use flat structure payload = payload["container"] @@ -62,7 +62,7 @@ def create_container_from_payload(payload: Dict[str, Any]) -> Container: ) -def create_resources_from_payload(payload: Dict[str, Any]) -> ContainerResources: +def create_resources_from_payload(payload: dict[str, Any]) -> ContainerResources: tpu = None if "tpu" in payload: tpu = create_tpu_resource_from_payload(payload["tpu"]) @@ -76,13 +76,13 @@ def create_resources_from_payload(payload: Dict[str, Any]) -> ContainerResources ) -def create_tpu_resource_from_payload(payload: Dict[str, Any]) -> ContainerTPUResource: +def create_tpu_resource_from_payload(payload: dict[str, Any]) -> ContainerTPUResource: return ContainerTPUResource( type=payload["type"], software_version=payload["software_version"] ) -def create_volume_from_payload(payload: Dict[str, Any]) -> ContainerVolume: +def create_volume_from_payload(payload: dict[str, Any]) -> ContainerVolume: dst_path = PurePath(payload["dst_path"]) return ContainerVolume.create( payload["src_storage_uri"], @@ -91,13 +91,13 @@ def create_volume_from_payload(payload: Dict[str, Any]) -> ContainerVolume: ) -def create_secret_volume_from_payload(payload: Dict[str, Any]) -> SecretContainerVolume: +def create_secret_volume_from_payload(payload: dict[str, Any]) -> SecretContainerVolume: return SecretContainerVolume.create( uri=payload["src_secret_uri"], dst_path=PurePath(payload["dst_path"]) ) -def create_disk_volume_from_payload(payload: Dict[str, Any]) -> DiskContainerVolume: +def create_disk_volume_from_payload(payload: dict[str, Any]) -> DiskContainerVolume: return DiskContainerVolume.create( uri=payload["src_disk_uri"], dst_path=PurePath(payload["dst_path"]), diff --git a/platform_api/handlers/jobs_handler.py b/platform_api/handlers/jobs_handler.py index 42ccab4fb..3f955c0d4 100644 --- a/platform_api/handlers/jobs_handler.py +++ b/platform_api/handlers/jobs_handler.py @@ -2,18 +2,9 @@ import json import logging from collections import defaultdict +from collections.abc import AsyncIterator, Sequence, Set from dataclasses import dataclass, replace -from typing import ( - AbstractSet, - Any, - AsyncIterator, - Dict, - List, - Optional, - Sequence, - Set, - Tuple, -) +from typing import Any, Optional import aiohttp.web import iso8601 @@ -90,8 +81,8 @@ def create_job_request_validator( storage_scheme: str = "storage", ) -> t.Trafaret: def _check_no_schedule_timeout_for_scheduled_jobs( - payload: Dict[str, Any] - ) -> Dict[str, Any]: + payload: dict[str, Any] + ) -> dict[str, Any]: if "schedule_timeout" in payload and payload["scheduler_enabled"]: raise t.DataError("schedule_timeout is not allowed for scheduled jobs") return payload @@ -109,7 +100,7 @@ def multiname_key( ) -> t.Key: _empty = object() - def _take_first(data: Dict[str, Any]) -> Dict[str, Any]: + def _take_first(data: dict[str, Any]) -> dict[str, Any]: for key in keys: if data[key] is not _empty: return trafaret(data[key]) @@ -166,7 +157,7 @@ def _take_first(data: Dict[str, Any]) -> Dict[str, Any]: def create_job_preset_validator(presets: Sequence[Preset]) -> t.Trafaret: - def _check_no_resources(payload: Dict[str, Any]) -> Dict[str, Any]: + def _check_no_resources(payload: dict[str, Any]) -> dict[str, Any]: if "container" in payload: resources = payload["container"].get("resources") else: @@ -177,7 +168,7 @@ def _check_no_resources(payload: Dict[str, Any]) -> Dict[str, Any]: raise t.DataError("Both preset and resources 
are not allowed") return payload - def _set_preset_resources(payload: Dict[str, Any]) -> Dict[str, Any]: + def _set_preset_resources(payload: dict[str, Any]) -> dict[str, Any]: preset_name = payload["preset_name"] preset = {p.name: p for p in presets}[preset_name] payload["scheduler_enabled"] = preset.scheduler_enabled @@ -306,7 +297,7 @@ def create_job_set_materialized_validator() -> t.Trafaret: def create_job_update_max_run_time_minutes_validator() -> t.Trafaret: - def _check_exactly_one(payload: Dict[str, Any]) -> Dict[str, Any]: + def _check_exactly_one(payload: dict[str, Any]) -> dict[str, Any]: if not payload or ( "max_run_time_minutes" in payload and "additional_max_run_time_minutes" in payload @@ -336,8 +327,8 @@ def create_drop_progress_validator() -> t.Trafaret: ) -def convert_job_container_to_json(container: Container) -> Dict[str, Any]: - ret: Dict[str, Any] = { +def convert_job_container_to_json(container: Container) -> dict[str, Any]: + ret: dict[str, Any] = { "image": container.image, "env": container.env, "volumes": [], @@ -347,7 +338,7 @@ def convert_job_container_to_json(container: Container) -> Dict[str, Any]: if container.command is not None: ret["command"] = container.command - resources: Dict[str, Any] = { + resources: dict[str, Any] = { "cpu": container.resources.cpu, "memory_mb": container.resources.memory_mb, } @@ -392,7 +383,7 @@ def convert_job_container_to_json(container: Container) -> Dict[str, Any]: return ret -def convert_container_volume_to_json(volume: ContainerVolume) -> Dict[str, Any]: +def convert_container_volume_to_json(volume: ContainerVolume) -> dict[str, Any]: return { "src_storage_uri": str(volume.uri), "dst_path": str(volume.dst_path), @@ -400,14 +391,14 @@ def convert_container_volume_to_json(volume: ContainerVolume) -> Dict[str, Any]: } -def convert_secret_volume_to_json(volume: SecretContainerVolume) -> Dict[str, Any]: +def convert_secret_volume_to_json(volume: SecretContainerVolume) -> dict[str, Any]: return { "src_secret_uri": str(volume.to_uri()), "dst_path": str(volume.dst_path), } -def convert_disk_volume_to_json(volume: DiskContainerVolume) -> Dict[str, Any]: +def convert_disk_volume_to_json(volume: DiskContainerVolume) -> dict[str, Any]: return { "src_disk_uri": str(volume.disk.to_uri()), "dst_path": str(volume.dst_path), @@ -415,13 +406,13 @@ def convert_disk_volume_to_json(volume: DiskContainerVolume) -> Dict[str, Any]: } -def convert_job_to_job_response(job: Job) -> Dict[str, Any]: +def convert_job_to_job_response(job: Job) -> dict[str, Any]: assert ( job.cluster_name ), "empty cluster name must be already replaced with `default`" history = job.status_history current_status = history.current - response_payload: Dict[str, Any] = { + response_payload: dict[str, Any] = { "id": job.id, "owner": job.owner, "cluster_name": job.cluster_name, @@ -491,7 +482,7 @@ def infer_permissions_from_container( registry_host: str, cluster_name: str, org_name: Optional[str], -) -> List[Permission]: +) -> list[Permission]: permissions = [ Permission(uri=str(make_job_uri(user, cluster_name, org_name)), action="write") ] @@ -814,7 +805,7 @@ async def limit_filter( async def _iter_filtered_jobs( self, bulk_job_filter: "BulkJobFilter", reverse: bool, limit: Optional[int] ) -> AsyncIterator[Job]: - def job_key(job: Job) -> Tuple[float, str, Job]: + def job_key(job: Job) -> tuple[float, str, Job]: return job.status_history.created_at_timestamp, job.id, job if bulk_job_filter.shared_ids: @@ -1055,7 +1046,7 @@ def create_from_query(self, query: MultiDictProxy) -> 
JobFilter: # type: ignore class BulkJobFilter: bulk_filter: Optional[JobFilter] - shared_ids: Set[str] + shared_ids: set[str] shared_ids_filter: Optional[JobFilter] @@ -1068,11 +1059,11 @@ def __init__( self._has_access_to_all: bool = False self._has_clusters_shared_all: bool = False - self._clusters_shared_any: Dict[str, Dict[str, Set[str]]] = defaultdict( + self._clusters_shared_any: dict[str, dict[str, set[str]]] = defaultdict( lambda: defaultdict(set) ) - self._owners_shared_any: Set[str] = set() - self._shared_ids: Set[str] = set() + self._owners_shared_any: set[str] = set() + self._shared_ids: set[str] = set() def build(self) -> BulkJobFilter: self._traverse_access_tree() @@ -1192,9 +1183,7 @@ def _create_bulk_filter(self) -> Optional[JobFilter]: ) return bulk_filter - def _optimize_clusters_owners( - self, owners: AbstractSet[str], name: Optional[str] - ) -> None: + def _optimize_clusters_owners(self, owners: Set[str], name: Optional[str]) -> None: if owners or name: names = {name} for cluster_owners in self._clusters_shared_any.values(): @@ -1221,7 +1210,7 @@ def _parse_bool(value: str) -> bool: async def check_any_permissions( - request: aiohttp.web.Request, permissions: List[Permission] + request: aiohttp.web.Request, permissions: list[Permission] ) -> None: user_name = await check_authorized(request) auth_policy = request.config_dict.get(AUTZ_KEY) @@ -1241,7 +1230,7 @@ async def check_any_permissions( ) -def _permission_to_primitive(perm: Permission) -> Dict[str, str]: +def _permission_to_primitive(perm: Permission) -> dict[str, str]: return {"uri": perm.uri, "action": perm.action} diff --git a/platform_api/handlers/validators.py b/platform_api/handlers/validators.py index 1c61ff123..dc440f31a 100644 --- a/platform_api/handlers/validators.py +++ b/platform_api/handlers/validators.py @@ -1,6 +1,7 @@ import shlex +from collections.abc import Sequence from pathlib import PurePath -from typing import Any, Dict, Optional, Sequence, Set, Union +from typing import Any, Optional, Union from urllib.parse import unquote, urlsplit import trafaret as t @@ -150,9 +151,9 @@ def _validate(path_str: str) -> str: def _validate_unique_volume_paths( - volumes: Sequence[Dict[str, Any]] -) -> Sequence[Dict[str, Any]]: - paths: Set[str] = set() + volumes: Sequence[dict[str, Any]] +) -> Sequence[dict[str, Any]]: + paths: set[str] = set() for volume in volumes: path = volume["dst_path"] if path in paths: diff --git a/platform_api/kube_cluster.py b/platform_api/kube_cluster.py index 1f12fce11..1bcdabe1b 100644 --- a/platform_api/kube_cluster.py +++ b/platform_api/kube_cluster.py @@ -1,6 +1,7 @@ import logging +from collections.abc import Sequence from contextlib import AsyncExitStack -from typing import List, Optional, Sequence +from typing import Optional import aiohttp @@ -22,7 +23,7 @@ def __init__( storage_configs: Sequence[StorageConfig], cluster_config: ClusterConfig, kube_config: KubeConfig, - trace_configs: Optional[List[aiohttp.TraceConfig]] = None, + trace_configs: Optional[list[aiohttp.TraceConfig]] = None, ) -> None: self._registry_config = registry_config self._storage_configs = storage_configs diff --git a/platform_api/log.py b/platform_api/log.py index 07d4801f3..454cdd745 100644 --- a/platform_api/log.py +++ b/platform_api/log.py @@ -1,7 +1,8 @@ import logging import time +from collections.abc import Iterator from contextlib import contextmanager -from typing import Any, Iterator +from typing import Any logger = logging.getLogger(__name__) diff --git a/platform_api/orchestrator/base.py 
b/platform_api/orchestrator/base.py index 6b0501a13..1e3c893c5 100644 --- a/platform_api/orchestrator/base.py +++ b/platform_api/orchestrator/base.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import List from .job import Job, JobStatusItem from .job_request import Disk, JobStatus @@ -20,10 +19,10 @@ async def delete_job(self, job: Job) -> JobStatus: @abstractmethod async def get_missing_secrets( - self, secret_path: str, secret_names: List[str] - ) -> List[str]: + self, secret_path: str, secret_names: list[str] + ) -> list[str]: pass @abstractmethod - async def get_missing_disks(self, disks: List[Disk]) -> List[Disk]: + async def get_missing_disks(self, disks: list[Disk]) -> list[Disk]: pass diff --git a/platform_api/orchestrator/base_postgres_storage.py b/platform_api/orchestrator/base_postgres_storage.py index 6e92c34d6..28b2ca929 100644 --- a/platform_api/orchestrator/base_postgres_storage.py +++ b/platform_api/orchestrator/base_postgres_storage.py @@ -1,7 +1,8 @@ import asyncio import sys -from contextlib import asynccontextmanager -from typing import AsyncContextManager, AsyncIterator, List, Optional +from collections.abc import AsyncIterator +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import Optional import sqlalchemy.sql as sasql from sqlalchemy.engine import Row @@ -43,7 +44,7 @@ async def _fetchrow( async def _fetch( self, query: sasql.ClauseElement, conn: Optional[AsyncConnection] = None - ) -> List[Row]: + ) -> list[Row]: if conn: result = await conn.execute(query) return result.all() @@ -61,7 +62,7 @@ async def _cursor( @asynccontextmanager async def _safe_connect( - conn_cm: AsyncContextManager[AsyncConnection], + conn_cm: AbstractAsyncContextManager[AsyncConnection], ) -> AsyncConnection: # Workaround of the SQLAlchemy bug. 
conn_task = asyncio.create_task(conn_cm.__aenter__()) diff --git a/platform_api/orchestrator/billing_log/service.py b/platform_api/orchestrator/billing_log/service.py index 3bc4c7710..aac189c16 100644 --- a/platform_api/orchestrator/billing_log/service.py +++ b/platform_api/orchestrator/billing_log/service.py @@ -1,7 +1,8 @@ import asyncio import logging -from contextlib import asynccontextmanager, suppress -from typing import Any, AsyncContextManager, AsyncIterator, Optional, Sequence +from collections.abc import AsyncIterator, Sequence +from contextlib import AbstractAsyncContextManager, asynccontextmanager, suppress +from typing import Any, Optional from aiohttp import ClientResponseError from neuro_admin_client import AdminClient @@ -30,7 +31,7 @@ def __init__( self._new_entry_notifier = new_entry self._entry_done_notifier = entry_done - self._notifier_cm: Optional[AsyncContextManager[Any]] = None + self._notifier_cm: Optional[AbstractAsyncContextManager[Any]] = None self._last_entry_id = 0 self._progress_cond = asyncio.Condition() @@ -108,7 +109,7 @@ def __init__( self._wait_timeout_s = wait_timeout_s self._task: Optional[asyncio.Task[Any]] = None - self._notifier_cm: Optional[AsyncContextManager[Any]] = None + self._notifier_cm: Optional[AbstractAsyncContextManager[Any]] = None async def __aenter__(self) -> "BillingLogWorker": self._unchecked_notify.set() # Run checks initially diff --git a/platform_api/orchestrator/billing_log/storage.py b/platform_api/orchestrator/billing_log/storage.py index 4c7825485..99699d100 100644 --- a/platform_api/orchestrator/billing_log/storage.py +++ b/platform_api/orchestrator/billing_log/storage.py @@ -1,19 +1,12 @@ import asyncio import logging from abc import ABC, abstractmethod -from contextlib import asynccontextmanager +from collections.abc import AsyncIterator, Sequence +from contextlib import AbstractAsyncContextManager, asynccontextmanager from dataclasses import dataclass, replace from datetime import datetime from decimal import Decimal -from typing import ( - Any, - AsyncContextManager, - AsyncIterator, - Dict, - List, - Optional, - Sequence, -) +from typing import Any, Optional import sqlalchemy as sa import sqlalchemy.dialects.postgresql as sapg @@ -71,13 +64,13 @@ async def insert( @abstractmethod def entries_inserter( self, - ) -> AsyncContextManager["BillingLogStorage.EntriesInserter"]: + ) -> AbstractAsyncContextManager["BillingLogStorage.EntriesInserter"]: pass @abstractmethod def iter_entries( self, *, with_ids_greater: int = 0, limit: Optional[int] = None - ) -> AsyncContextManager[AsyncIterator[BillingLogEntry]]: + ) -> AbstractAsyncContextManager[AsyncIterator[BillingLogEntry]]: pass @abstractmethod @@ -91,7 +84,7 @@ async def get_last_entry_id(self, job_id: Optional[str] = None) -> int: class InMemoryBillingLogStorage(BillingLogStorage): def __init__(self) -> None: - self._entries: List[BillingLogEntry] = [] + self._entries: list[BillingLogEntry] = [] self._sync_record: Optional[BillingLogSyncRecord] = None self._inserter_lock = asyncio.Lock() self._dropped_cnt: int = 0 @@ -189,7 +182,7 @@ def __init__( # Parsing/serialization - def _log_entry_to_values(self, entry: BillingLogEntry) -> Dict[str, Any]: + def _log_entry_to_values(self, entry: BillingLogEntry) -> dict[str, Any]: return { "job_id": entry.job_id, "payload": { @@ -213,13 +206,13 @@ def _record_to_log_entry(self, record: Row) -> BillingLogEntry: def _sync_record_to_values( self, sync_record: BillingLogSyncRecord - ) -> Dict[str, Any]: + ) -> dict[str, Any]: return { 
"type": self.BILLING_SYNC_RECORD_TYPE, "last_entry_id": sync_record.last_entry_id, } - def _record_to_sync_record(self, record: Dict[str, Any]) -> BillingLogSyncRecord: + def _record_to_sync_record(self, record: dict[str, Any]) -> BillingLogSyncRecord: assert record["type"] == self.BILLING_SYNC_RECORD_TYPE return BillingLogSyncRecord(last_entry_id=record["last_entry_id"]) diff --git a/platform_api/orchestrator/job.py b/platform_api/orchestrator/job.py index e0c826b0a..7f7b8c88c 100644 --- a/platform_api/orchestrator/job.py +++ b/platform_api/orchestrator/job.py @@ -1,10 +1,11 @@ import enum import logging +from collections.abc import Callable, Iterable, Iterator, Sequence from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone from decimal import Decimal from functools import partial -from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Sequence +from typing import Any, Optional import iso8601 from yarl import URL @@ -100,7 +101,7 @@ def create( return cls(status=status, transition_time=transition_time, **kwargs) @classmethod - def from_primitive(cls, payload: Dict[str, Any]) -> "JobStatusItem": + def from_primitive(cls, payload: dict[str, Any]) -> "JobStatusItem": status = JobStatus(payload["status"]) transition_time = iso8601.parse_date(payload["transition_time"]) return cls( @@ -111,8 +112,8 @@ def from_primitive(cls, payload: Dict[str, Any]) -> "JobStatusItem": exit_code=payload.get("exit_code"), ) - def to_primitive(self) -> Dict[str, Any]: - result: Dict[str, Any] = { + def to_primitive(self) -> dict[str, Any]: + result: dict[str, Any] = { "status": str(self.status.value), "transition_time": self.transition_time.isoformat(), "reason": self.reason, @@ -124,7 +125,7 @@ def to_primitive(self) -> Dict[str, Any]: class JobStatusHistory: - def __init__(self, items: List[JobStatusItem]) -> None: + def __init__(self, items: list[JobStatusItem]) -> None: assert items, "JobStatusHistory should contain at least one entry" self._items = items @@ -430,7 +431,7 @@ def should_be_deleted( ) ) - def to_primitive(self) -> Dict[str, Any]: + def to_primitive(self) -> dict[str, Any]: if not self.allow_empty_cluster_name and not self.cluster_name: raise RuntimeError( "empty cluster name must be already replaced with `default`" @@ -481,7 +482,7 @@ def to_primitive(self) -> Dict[str, Any]: @classmethod def from_primitive( cls, - payload: Dict[str, Any], + payload: dict[str, Any], orphaned_job_owner: str = DEFAULT_ORPHANED_JOB_OWNER, ) -> "JobRecord": request = JobRequest.from_primitive(payload["request"]) @@ -522,7 +523,7 @@ def from_primitive( @staticmethod def create_status_history_from_primitive( - job_id: str, payload: Dict[str, Any] + job_id: str, payload: dict[str, Any] ) -> JobStatusHistory: if "statuses" in payload: # already migrated to history @@ -878,14 +879,14 @@ def total_price_credits(self) -> Decimal: def org_name(self) -> Optional[str]: return self._record.org_name - def to_primitive(self) -> Dict[str, Any]: + def to_primitive(self) -> dict[str, Any]: return self._record.to_primitive() @classmethod def from_primitive( cls, orchestrator_config: OrchestratorConfig, - payload: Dict[str, Any], + payload: dict[str, Any], ) -> "Job": record = JobRecord.from_primitive(payload) return cls( diff --git a/platform_api/orchestrator/job_policy_enforcer.py b/platform_api/orchestrator/job_policy_enforcer.py index 454660373..56410388f 100644 --- a/platform_api/orchestrator/job_policy_enforcer.py +++ b/platform_api/orchestrator/job_policy_enforcer.py 
@@ -4,20 +4,10 @@ import logging import uuid from collections import defaultdict +from collections.abc import Callable, Iterable, Mapping from datetime import datetime, timedelta, timezone from decimal import Decimal -from typing import ( - Any, - Callable, - Dict, - Iterable, - List, - Mapping, - Optional, - Set, - Tuple, - TypeVar, -) +from typing import Any, Optional, TypeVar from aiohttp import ClientResponseError from neuro_admin_client import AdminClient, ClusterUser, OrgCluster @@ -61,7 +51,7 @@ def __init__( self._admin_client = admin_client self._notifications_client = notifications_client self._threshold = notification_threshold - self._sent: Dict[Tuple[str, str], Optional[Decimal]] = defaultdict(lambda: None) + self._sent: dict[tuple[str, str], Optional[Decimal]] = defaultdict(lambda: None) async def _notify_user_if_needed( self, @@ -89,7 +79,7 @@ async def _notify_user_if_needed( @trace async def enforce(self) -> None: - user_to_clusters: Dict[str, Set[Tuple[str, Optional[str]]]] = defaultdict(set) + user_to_clusters: dict[str, set[tuple[str, Optional[str]]]] = defaultdict(set) job_filter = JobFilter( statuses={JobStatus(item) for item in JobStatus.active_values()} ) @@ -102,7 +92,7 @@ async def enforce(self) -> None: ) async def _enforce_for_user( - self, username: str, clusters_and_orgs: Set[Tuple[str, Optional[str]]] + self, username: str, clusters_and_orgs: set[tuple[str, Optional[str]]] ) -> None: base_name = username.split("/", 1)[0] # SA inherit balance from main user _, cluster_users = await self._admin_client.get_user_with_clusters(base_name) @@ -161,7 +151,7 @@ def __init__(self, service: JobsService, admin_client: AdminClient): def _groupby( self, it: Iterable[_T], key: Callable[[_T], _K] - ) -> Mapping[_K, List[_T]]: + ) -> Mapping[_K, list[_T]]: res = defaultdict(list) for item in it: res[key(item)].append(item) @@ -331,7 +321,7 @@ async def enforce(self) -> None: class JobPolicyEnforcePoller: def __init__( - self, config: JobPolicyEnforcerConfig, enforcers: List[JobPolicyEnforcer] + self, config: JobPolicyEnforcerConfig, enforcers: list[JobPolicyEnforcer] ) -> None: self._loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() self._enforcers = enforcers diff --git a/platform_api/orchestrator/job_request.py b/platform_api/orchestrator/job_request.py index 7821824e7..067669cff 100644 --- a/platform_api/orchestrator/job_request.py +++ b/platform_api/orchestrator/job_request.py @@ -2,9 +2,10 @@ import shlex import uuid from collections import defaultdict +from collections.abc import Sequence from dataclasses import asdict, dataclass, field from pathlib import PurePath -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Optional, Union from yarl import URL @@ -44,15 +45,15 @@ def create( return cls(uri=URL(uri), dst_path=dst_path, read_only=read_only) @classmethod - def from_primitive(cls, payload: Dict[str, Any]) -> "ContainerVolume": + def from_primitive(cls, payload: dict[str, Any]) -> "ContainerVolume": return cls( uri=URL(payload.get("uri", "")), dst_path=PurePath(payload["dst_path"]), read_only=payload["read_only"], ) - def to_primitive(self) -> Dict[str, Any]: - payload: Dict[str, Any] = asdict(self) + def to_primitive(self) -> dict[str, Any]: + payload: dict[str, Any] = asdict(self) payload["uri"] = str(payload["uri"]) payload["dst_path"] = str(payload["dst_path"]) return payload @@ -98,14 +99,14 @@ def create( return cls(disk=Disk.create(uri), dst_path=dst_path, read_only=read_only) @classmethod - def from_primitive(cls, 
payload: Dict[str, Any]) -> "DiskContainerVolume": + def from_primitive(cls, payload: dict[str, Any]) -> "DiskContainerVolume": return cls.create( uri=payload["src_disk_uri"], dst_path=PurePath(payload["dst_path"]), read_only=payload["read_only"], ) - def to_primitive(self) -> Dict[str, Any]: + def to_primitive(self) -> dict[str, Any]: return { "src_disk_uri": str(self.to_uri()), "dst_path": str(self.dst_path), @@ -156,12 +157,12 @@ def create(cls, uri: str, dst_path: PurePath) -> "SecretContainerVolume": return cls(secret=Secret.create(uri), dst_path=dst_path) @classmethod - def from_primitive(cls, payload: Dict[str, Any]) -> "SecretContainerVolume": + def from_primitive(cls, payload: dict[str, Any]) -> "SecretContainerVolume": return cls.create( uri=payload["src_secret_uri"], dst_path=PurePath(payload["dst_path"]) ) - def to_primitive(self) -> Dict[str, Any]: + def to_primitive(self) -> dict[str, Any]: return { "src_secret_uri": str(self.to_uri()), "dst_path": str(self.dst_path), @@ -174,10 +175,10 @@ class ContainerTPUResource: software_version: str @classmethod - def from_primitive(cls, payload: Dict[str, Any]) -> "ContainerTPUResource": + def from_primitive(cls, payload: dict[str, Any]) -> "ContainerTPUResource": return cls(type=payload["type"], software_version=payload["software_version"]) - def to_primitive(self) -> Dict[str, Any]: + def to_primitive(self) -> dict[str, Any]: return {"type": self.type, "software_version": self.software_version} @@ -191,7 +192,7 @@ class ContainerResources: tpu: Optional[ContainerTPUResource] = None @classmethod - def from_primitive(cls, payload: Dict[str, Any]) -> "ContainerResources": + def from_primitive(cls, payload: dict[str, Any]) -> "ContainerResources": tpu = None if payload.get("tpu"): tpu = ContainerTPUResource.from_primitive(payload["tpu"]) @@ -204,8 +205,8 @@ def from_primitive(cls, payload: Dict[str, Any]) -> "ContainerResources": tpu=tpu, ) - def to_primitive(self) -> Dict[str, Any]: - payload: Dict[str, Any] = {"cpu": self.cpu, "memory_mb": self.memory_mb} + def to_primitive(self) -> dict[str, Any]: + payload: dict[str, Any] = {"cpu": self.cpu, "memory_mb": self.memory_mb} if self.gpu is not None: payload["gpu"] = self.gpu payload["gpu_model_id"] = self.gpu_model_id @@ -294,14 +295,14 @@ class ContainerHTTPServer: requires_auth: bool = False @classmethod - def from_primitive(cls, payload: Dict[str, Any]) -> "ContainerHTTPServer": + def from_primitive(cls, payload: dict[str, Any]) -> "ContainerHTTPServer": return cls( port=payload["port"], health_check_path=payload.get("health_check_path") or cls.health_check_path, requires_auth=payload.get("requires_auth", cls.requires_auth), ) - def to_primitive(self) -> Dict[str, Any]: + def to_primitive(self) -> dict[str, Any]: return asdict(self) @@ -311,11 +312,11 @@ class Container: resources: ContainerResources entrypoint: Optional[str] = None command: Optional[str] = None - env: Dict[str, str] = field(default_factory=dict) - volumes: List[ContainerVolume] = field(default_factory=list) - secret_env: Dict[str, Secret] = field(default_factory=dict) - secret_volumes: List[SecretContainerVolume] = field(default_factory=list) - disk_volumes: List[DiskContainerVolume] = field(default_factory=list) + env: dict[str, str] = field(default_factory=dict) + volumes: list[ContainerVolume] = field(default_factory=list) + secret_env: dict[str, Secret] = field(default_factory=dict) + secret_volumes: list[SecretContainerVolume] = field(default_factory=list) + disk_volumes: list[DiskContainerVolume] = 
field(default_factory=list) http_server: Optional[ContainerHTTPServer] = None tty: bool = False working_dir: Optional[str] = None @@ -332,19 +333,19 @@ def to_image_uri(self, registry_host: str, cluster_name: str) -> URL: assert cluster_name return URL.build(scheme="image", host=cluster_name) / path - def get_secrets(self) -> List[Secret]: + def get_secrets(self) -> list[Secret]: return list( {*self.secret_env.values(), *(v.secret for v in self.secret_volumes)} ) - def get_path_to_secrets(self) -> Dict[str, List[Secret]]: - path_to_secrets: Dict[str, List[Secret]] = defaultdict(list) + def get_path_to_secrets(self) -> dict[str, list[Secret]]: + path_to_secrets: dict[str, list[Secret]] = defaultdict(list) for secret in self.get_secrets(): path_to_secrets[secret.path].append(secret) return path_to_secrets - def get_path_to_secret_volumes(self) -> Dict[str, List[SecretContainerVolume]]: - user_volumes: Dict[str, List[SecretContainerVolume]] = defaultdict(list) + def get_path_to_secret_volumes(self) -> dict[str, list[SecretContainerVolume]]: + user_volumes: dict[str, list[SecretContainerVolume]] = defaultdict(list) for volume in self.secret_volumes: user_volumes[volume.secret.path].append(volume) return user_volumes @@ -366,20 +367,20 @@ def health_check_path(self) -> str: return self.http_server.health_check_path return ContainerHTTPServer.health_check_path - def _parse_command(self, command: str) -> List[str]: + def _parse_command(self, command: str) -> list[str]: try: return shlex.split(command) except ValueError: raise JobError("invalid command format") @property - def entrypoint_list(self) -> List[str]: + def entrypoint_list(self) -> list[str]: if self.entrypoint: return self._parse_command(self.entrypoint) return [] @property - def command_list(self) -> List[str]: + def command_list(self) -> list[str]: if self.command: return self._parse_command(self.command) return [] @@ -393,7 +394,7 @@ def requires_http_auth(self) -> bool: return bool(self.http_server and self.http_server.requires_auth) @classmethod - def from_primitive(cls, payload: Dict[str, Any]) -> "Container": + def from_primitive(cls, payload: dict[str, Any]) -> "Container": kwargs = payload.copy() kwargs["resources"] = ContainerResources.from_primitive(kwargs["resources"]) kwargs["volumes"] = [ @@ -431,8 +432,8 @@ def from_primitive(cls, payload: Dict[str, Any]) -> "Container": return cls(**kwargs) - def to_primitive(self) -> Dict[str, Any]: - payload: Dict[str, Any] = asdict(self) + def to_primitive(self) -> dict[str, Any]: + payload: dict[str, Any] = asdict(self) payload["resources"] = self.resources.to_primitive() payload["volumes"] = [volume.to_primitive() for volume in self.volumes] @@ -483,12 +484,12 @@ def create( ) @classmethod - def from_primitive(cls, payload: Dict[str, Any]) -> "JobRequest": + def from_primitive(cls, payload: dict[str, Any]) -> "JobRequest": kwargs = payload.copy() kwargs["container"] = Container.from_primitive(kwargs["container"]) return cls(**kwargs) - def to_primitive(self) -> Dict[str, Any]: + def to_primitive(self) -> dict[str, Any]: result = {"job_id": self.job_id, "container": self.container.to_primitive()} if self.description: result["description"] = self.description @@ -527,15 +528,15 @@ def is_finished(self) -> bool: return self in (self.SUCCEEDED, self.FAILED, self.CANCELLED) @classmethod - def values(cls) -> List[str]: + def values(cls) -> list[str]: return [item.value for item in cls] @classmethod - def active_values(cls) -> List[str]: + def active_values(cls) -> list[str]: return 
[item.value for item in cls if not item.is_finished] @classmethod - def finished_values(cls) -> List[str]: + def finished_values(cls) -> list[str]: return [item.value for item in cls if item.is_finished] def __repr__(self) -> str: diff --git a/platform_api/orchestrator/jobs_poller.py b/platform_api/orchestrator/jobs_poller.py index 70eb087d7..b956c0c5b 100644 --- a/platform_api/orchestrator/jobs_poller.py +++ b/platform_api/orchestrator/jobs_poller.py @@ -1,8 +1,9 @@ import asyncio import logging +from collections.abc import Mapping from datetime import timedelta from pathlib import PurePath -from typing import Any, Dict, List, Mapping, Optional +from typing import Any, Optional import aiohttp from iso8601 import iso8601 @@ -50,12 +51,12 @@ def _parse_container_volume(data: Mapping[str, Any]) -> ContainerVolume: read_only=bool(data.get("read_only")), ) - def _parse_secret_volume(payload: Dict[str, Any]) -> SecretContainerVolume: + def _parse_secret_volume(payload: dict[str, Any]) -> SecretContainerVolume: return SecretContainerVolume.create( uri=payload["src_secret_uri"], dst_path=PurePath(payload["dst_path"]) ) - def _parse_disk_volume(payload: Dict[str, Any]) -> DiskContainerVolume: + def _parse_disk_volume(payload: dict[str, Any]) -> DiskContainerVolume: return DiskContainerVolume.create( uri=payload["src_disk_uri"], dst_path=PurePath(payload["dst_path"]), @@ -150,7 +151,7 @@ def __init__( url: URL, token: str, cluster_name: str, - trace_configs: Optional[List[aiohttp.TraceConfig]] = None, + trace_configs: Optional[list[aiohttp.TraceConfig]] = None, ): self._base_url = url self._token = token @@ -160,7 +161,7 @@ def __init__( async def init(self) -> None: if self._client: return - headers: Dict[str, str] = {} + headers: dict[str, str] = {} if self._token: headers["Authorization"] = f"Bearer {self._token}" self._client = aiohttp.ClientSession( @@ -179,7 +180,7 @@ async def __aenter__(self) -> "HttpJobsPollerApi": async def __aexit__(self, *args: Any) -> None: await self.close() - async def get_unfinished_jobs(self) -> List[JobRecord]: + async def get_unfinished_jobs(self) -> list[JobRecord]: assert self._client url = self._base_url / "jobs" params: MultiDict[Any] = MultiDict() @@ -190,7 +191,7 @@ async def get_unfinished_jobs(self) -> List[JobRecord]: payload = await resp.json() return [job_response_to_job_record(item) for item in payload["jobs"]] - async def get_jobs_for_deletion(self, *, delay: timedelta) -> List[JobRecord]: + async def get_jobs_for_deletion(self, *, delay: timedelta) -> list[JobRecord]: assert self._client url = self._base_url / "jobs" params: MultiDict[Any] = MultiDict() diff --git a/platform_api/orchestrator/jobs_service.py b/platform_api/orchestrator/jobs_service.py index 1b7c2a8b8..49c40ed81 100644 --- a/platform_api/orchestrator/jobs_service.py +++ b/platform_api/orchestrator/jobs_service.py @@ -2,11 +2,12 @@ import json import logging from collections import defaultdict +from collections.abc import AsyncIterator, Iterable, Sequence from contextlib import asynccontextmanager from dataclasses import dataclass, replace from datetime import datetime, timedelta from decimal import Decimal -from typing import AsyncIterator, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Optional, Union from aiohttp import ClientResponseError from neuro_admin_client import AdminClient, ClusterUser, OrgCluster @@ -82,7 +83,7 @@ def create_for_org(cls, org: str) -> "NoCreditsError": class UserClusterConfig: config: ClusterConfig # None value means the direct access 
to cluster without any org - orgs: List[Optional[str]] + orgs: list[Optional[str]] class JobsService: @@ -231,7 +232,7 @@ async def create_job( schedule_timeout: Optional[float] = None, max_run_time_minutes: Optional[int] = None, restart_policy: JobRestartPolicy = JobRestartPolicy.NEVER, - ) -> Tuple[Job, Status]: + ) -> tuple[Job, Status]: base_name = user.name.split("/", 1)[ 0 ] # SA has access to same clusters as a user @@ -421,7 +422,7 @@ async def iter_all_jobs( async def get_all_jobs( self, job_filter: Optional[JobFilter] = None, *, reverse: bool = False - ) -> List[Job]: + ) -> list[Job]: async with self.iter_all_jobs(job_filter, reverse=reverse) as it: return [job async for job in it] @@ -436,13 +437,13 @@ async def get_job_by_name(self, job_name: str, owner: AuthUser) -> Job: async def get_jobs_by_ids( self, job_ids: Iterable[str], job_filter: Optional[JobFilter] = None - ) -> List[Job]: + ) -> list[Job]: records = await self._jobs_storage.get_jobs_by_ids( job_ids, job_filter=job_filter ) return [await self._get_cluster_job(record) for record in records] - async def get_user_cluster_configs(self, user: AuthUser) -> List[UserClusterConfig]: + async def get_user_cluster_configs(self, user: AuthUser) -> list[UserClusterConfig]: configs = [] base_name = user.name.split("/", 1)[0] # SA has access to same clusters as user cluster_to_orgs = defaultdict(list) @@ -539,7 +540,7 @@ async def get_not_billed_jobs(self) -> AsyncIterator[Job]: async def get_job_ids_for_drop( self, *, delay: timedelta, limit: Optional[int] = None - ) -> List[str]: + ) -> list[str]: return [ record.id for record in await self._jobs_storage.get_jobs_for_drop( diff --git a/platform_api/orchestrator/jobs_storage/base.py b/platform_api/orchestrator/jobs_storage/base.py index faf5f66f7..21f11a488 100644 --- a/platform_api/orchestrator/jobs_storage/base.py +++ b/platform_api/orchestrator/jobs_storage/base.py @@ -1,19 +1,10 @@ import logging from abc import ABC, abstractmethod +from collections.abc import AsyncIterator, Iterable, Set +from contextlib import AbstractAsyncContextManager from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone -from typing import ( - AbstractSet, - AsyncContextManager, - AsyncIterator, - Dict, - Iterable, - List, - Optional, - Set, - Type, - cast, -) +from typing import Optional, cast from platform_api.orchestrator.job import JobRecord from platform_api.orchestrator.job_request import JobStatus @@ -38,25 +29,23 @@ def __init__(self, job_name: str, job_owner: str, found_job_id: str): ) -ClusterOwnerNameSet = Dict[str, Dict[str, AbstractSet[str]]] +ClusterOwnerNameSet = dict[str, dict[str, Set[str]]] @dataclass(frozen=True) class JobFilter: - statuses: AbstractSet[JobStatus] = field( - default_factory=cast(Type[Set[JobStatus]], set) - ) + statuses: Set[JobStatus] = field(default_factory=cast(type[Set[JobStatus]], set)) clusters: ClusterOwnerNameSet = field( - default_factory=cast(Type[ClusterOwnerNameSet], dict) + default_factory=cast(type[ClusterOwnerNameSet], dict) ) - orgs: AbstractSet[Optional[str]] = field( - default_factory=cast(Type[Set[Optional[str]]], set) + orgs: Set[Optional[str]] = field( + default_factory=cast(type[Set[Optional[str]]], set) ) - owners: AbstractSet[str] = field(default_factory=cast(Type[Set[str]], set)) - base_owners: AbstractSet[str] = field(default_factory=cast(Type[Set[str]], set)) - tags: Set[str] = field(default_factory=cast(Type[Set[str]], set)) + owners: Set[str] = field(default_factory=cast(type[Set[str]], set)) +
base_owners: Set[str] = field(default_factory=cast(type[Set[str]], set)) + tags: Set[str] = field(default_factory=cast(type[Set[str]], set)) name: Optional[str] = None - ids: AbstractSet[str] = field(default_factory=cast(Type[Set[str]], set)) + ids: Set[str] = field(default_factory=cast(type[Set[str]], set)) since: datetime = datetime(1, 1, 1, tzinfo=timezone.utc) until: datetime = datetime(9999, 12, 31, 23, 59, 59, 999999, tzinfo=timezone.utc) materialized: Optional[bool] = None @@ -85,7 +74,7 @@ def check(self, job: JobRecord) -> bool: return False if self.ids and job.id not in self.ids: return False - if self.tags and not self.tags.issubset(job.tags): + if self.tags and not self.tags <= set(job.tags): return False created_at = job.status_history.created_at if not self.since <= created_at <= self.until: @@ -103,7 +92,7 @@ def check(self, job: JobRecord) -> bool: class JobsStorage(ABC): @abstractmethod - def try_create_job(self, job: JobRecord) -> AsyncContextManager[JobRecord]: + def try_create_job(self, job: JobRecord) -> AbstractAsyncContextManager[JobRecord]: pass @abstractmethod @@ -119,7 +108,7 @@ async def drop_job(self, job_id: str) -> None: pass @abstractmethod - def try_update_job(self, job_id: str) -> AsyncContextManager[JobRecord]: + def try_update_job(self, job_id: str) -> AbstractAsyncContextManager[JobRecord]: pass @abstractmethod @@ -129,13 +118,13 @@ def iter_all_jobs( *, reverse: bool = False, limit: Optional[int] = None, - ) -> AsyncContextManager[AsyncIterator[JobRecord]]: + ) -> AbstractAsyncContextManager[AsyncIterator[JobRecord]]: pass @abstractmethod async def get_jobs_by_ids( self, job_ids: Iterable[str], job_filter: Optional[JobFilter] = None - ) -> List[JobRecord]: + ) -> list[JobRecord]: pass # Only used in tests @@ -144,17 +133,17 @@ async def get_all_jobs( job_filter: Optional[JobFilter] = None, reverse: bool = False, limit: Optional[int] = None, - ) -> List[JobRecord]: + ) -> list[JobRecord]: async with self.iter_all_jobs(job_filter, reverse=reverse, limit=limit) as it: return [job async for job in it] # Only used in tests - async def get_running_jobs(self) -> List[JobRecord]: + async def get_running_jobs(self) -> list[JobRecord]: filt = JobFilter(statuses={JobStatus.RUNNING}) return await self.get_all_jobs(filt) # Only used in tests - async def get_unfinished_jobs(self) -> List[JobRecord]: + async def get_unfinished_jobs(self) -> list[JobRecord]: filt = JobFilter( statuses={JobStatus.PENDING, JobStatus.RUNNING, JobStatus.SUSPENDED} ) @@ -163,11 +152,11 @@ async def get_unfinished_jobs(self) -> List[JobRecord]: @abstractmethod async def get_jobs_for_deletion( self, *, delay: timedelta = timedelta() - ) -> List[JobRecord]: + ) -> list[JobRecord]: pass @abstractmethod async def get_jobs_for_drop( self, *, delay: timedelta = timedelta(), limit: Optional[int] = None - ) -> List[JobRecord]: + ) -> list[JobRecord]: pass diff --git a/platform_api/orchestrator/jobs_storage/in_memory.py b/platform_api/orchestrator/jobs_storage/in_memory.py index 5d003a88a..ed4527394 100644 --- a/platform_api/orchestrator/jobs_storage/in_memory.py +++ b/platform_api/orchestrator/jobs_storage/in_memory.py @@ -1,7 +1,8 @@ import json +from collections.abc import AsyncIterator, Iterable from contextlib import asynccontextmanager from datetime import datetime, timedelta, timezone -from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple +from typing import Optional from platform_api.orchestrator.job import JobRecord from platform_api.orchestrator.job_request import JobError 
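The jobs_storage/base.py hunks above are slightly more than the mechanical rename: typing.AbstractSet becomes collections.abc.Set and typing.AsyncContextManager becomes contextlib.AbstractAsyncContextManager, both of which are parameterizable on 3.9 under PEP 585. That switch also explains why JobFilter.check now uses self.tags <= set(job.tags) instead of self.tags.issubset(job.tags): the Set ABC guarantees the rich comparison operators but not set's named methods, and the <= comparison needs an actual set on the right-hand side. A short sketch of both idioms, with illustrative names:

    from collections.abc import AsyncIterator, Set
    from contextlib import AbstractAsyncContextManager, asynccontextmanager

    def tags_match(wanted: Set[str], job_tags: list[str]) -> bool:
        # abc.Set defines __le__ but has no issubset(), so wrap the plain
        # iterable in a concrete set before comparing.
        return wanted <= set(job_tags)

    # Same shape as JobsStorage.iter_all_jobs() above: declaring a
    # context-manager return type without the typing alias.
    def iter_ids() -> AbstractAsyncContextManager[AsyncIterator[str]]:
        @asynccontextmanager
        async def _cm() -> AsyncIterator[AsyncIterator[str]]:
            async def _gen() -> AsyncIterator[str]:
                yield "job-1"
            yield _gen()
        return _cm()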
@@ -13,11 +14,11 @@ class InMemoryJobsStorage(JobsStorage): def __init__(self) -> None: # job_id to job mapping: - self._job_records: Dict[str, str] = {} + self._job_records: dict[str, str] = {} # job_name+owner to job_id mapping: - self._last_alive_job_records: Dict[Tuple[str, str], str] = {} + self._last_alive_job_records: dict[tuple[str, str], str] = {} # owner to job tags mapping: - self._owner_to_tags: Dict[str, List[str]] = {} + self._owner_to_tags: dict[str, list[str]] = {} @asynccontextmanager async def try_create_job(self, job: JobRecord) -> AsyncIterator[JobRecord]: @@ -91,7 +92,7 @@ async def iter_all_jobs( async def get_jobs_by_ids( self, job_ids: Iterable[str], job_filter: Optional[JobFilter] = None - ) -> List[JobRecord]: + ) -> list[JobRecord]: jobs = [] for job_id in job_ids: try: @@ -105,7 +106,7 @@ async def get_jobs_by_ids( async def get_jobs_for_deletion( self, *, delay: timedelta = timedelta() - ) -> List[JobRecord]: + ) -> list[JobRecord]: async with self.iter_all_jobs() as it: return [job async for job in it if job.should_be_deleted(delay=delay)] @@ -114,7 +115,7 @@ async def get_jobs_for_drop( *, delay: timedelta = timedelta(), limit: Optional[int] = None, - ) -> List[JobRecord]: + ) -> list[JobRecord]: now = datetime.now(timezone.utc) jobs = [] async with self.iter_all_jobs() as it: diff --git a/platform_api/orchestrator/jobs_storage/postgres.py b/platform_api/orchestrator/jobs_storage/postgres.py index c941a8e2c..12937c29b 100644 --- a/platform_api/orchestrator/jobs_storage/postgres.py +++ b/platform_api/orchestrator/jobs_storage/postgres.py @@ -1,16 +1,8 @@ +from collections.abc import AsyncIterator, Iterable, Mapping, Set from contextlib import asynccontextmanager from dataclasses import dataclass, replace from datetime import datetime, timedelta, timezone -from typing import ( - AbstractSet, - Any, - AsyncIterator, - Dict, - Iterable, - List, - Mapping, - Optional, -) +from typing import Any, Optional import sqlalchemy as sa import sqlalchemy.dialects.postgresql as sapg @@ -74,7 +66,7 @@ def __init__(self, engine: AsyncEngine, tables: Optional[JobTables] = None) -> N # Parsing/serialization - def _job_to_values(self, job: JobRecord) -> Dict[str, Any]: + def _job_to_values(self, job: JobRecord) -> dict[str, Any]: payload = job.to_primitive() return { "id": payload.pop("id"), @@ -260,7 +252,7 @@ async def iter_all_jobs( async def get_jobs_by_ids( self, job_ids: Iterable[str], job_filter: Optional[JobFilter] = None - ) -> List[JobRecord]: + ) -> list[JobRecord]: if job_filter is None: job_filter = JobFilter() if job_filter.ids: @@ -279,7 +271,7 @@ async def get_jobs_by_ids( async def get_jobs_for_deletion( self, *, delay: timedelta = timedelta() - ) -> List[JobRecord]: + ) -> list[JobRecord]: job_filter = JobFilter( statuses={JobStatus(item) for item in JobStatus.finished_values()}, materialized=True, @@ -296,7 +288,7 @@ async def get_jobs_for_drop( *, delay: timedelta = timedelta(), limit: Optional[int] = None, - ) -> List[JobRecord]: + ) -> list[JobRecord]: job_filter = JobFilter( statuses={JobStatus(item) for item in JobStatus.finished_values()}, materialized=False, @@ -315,16 +307,16 @@ async def get_jobs_for_drop( class JobFilterClauseBuilder: def __init__(self, tables: JobTables): - self._clauses: List[sasql.ClauseElement] = [] + self._clauses: list[sasql.ClauseElement] = [] self._tables = tables - def filter_statuses(self, statuses: AbstractSet[JobStatus]) -> None: + def filter_statuses(self, statuses: Set[JobStatus]) -> None: 
self._clauses.append(self._tables.jobs.c.status.in_(statuses)) - def filter_owners(self, owners: AbstractSet[str]) -> None: + def filter_owners(self, owners: Set[str]) -> None: self._clauses.append(self._tables.jobs.c.owner.in_(owners)) - def filter_base_owners(self, base_owners: AbstractSet[str]) -> None: + def filter_base_owners(self, base_owners: Set[str]) -> None: self._clauses.append( func.split_part(self._tables.jobs.c.owner, "/", 1).in_(base_owners) ) @@ -355,7 +347,7 @@ def filter_clusters(self, clusters: ClusterOwnerNameSet) -> None: ) self._clauses.append(or_(*cluster_clauses)) - def filter_orgs(self, orgs: AbstractSet[Optional[str]]) -> None: + def filter_orgs(self, orgs: Set[Optional[str]]) -> None: not_null_orgs = [org for org in orgs if org is not None] or_clauses = [] if not_null_orgs: @@ -367,10 +359,10 @@ def filter_orgs(self, orgs: AbstractSet[Optional[str]]) -> None: def filter_name(self, name: str) -> None: self._clauses.append(self._tables.jobs.c.name == name) - def filter_ids(self, ids: AbstractSet[str]) -> None: + def filter_ids(self, ids: Set[str]) -> None: self._clauses.append(self._tables.jobs.c.id.in_(ids)) - def filter_tags(self, tags: AbstractSet[str]) -> None: + def filter_tags(self, tags: Set[str]) -> None: self._clauses.append(self._tables.jobs.c.tags.contains(list(tags))) def filter_since(self, since: datetime) -> None: diff --git a/platform_api/orchestrator/kube_client.py b/platform_api/orchestrator/kube_client.py index 246d6fe2c..99b942cd6 100644 --- a/platform_api/orchestrator/kube_client.py +++ b/platform_api/orchestrator/kube_client.py @@ -6,28 +6,15 @@ import re import ssl from base64 import b64encode +from collections import defaultdict +from collections.abc import AsyncIterator, Callable, Iterable, Sequence from contextlib import suppress from dataclasses import dataclass, field, replace from datetime import datetime from enum import Enum from pathlib import Path, PurePath from types import TracebackType -from typing import ( - Any, - AsyncIterator, - Callable, - ClassVar, - DefaultDict, - Dict, - Iterable, - List, - NoReturn, - Optional, - Sequence, - Tuple, - Type, - Union, -) +from typing import Any, ClassVar, NoReturn, Optional, Union from urllib.parse import urlsplit import aiohttp @@ -79,7 +66,7 @@ class NotFoundException(StatusException): pass -def _raise_status_job_exception(pod: Dict[str, Any], job_id: Optional[str]) -> NoReturn: +def _raise_status_job_exception(pod: dict[str, Any], job_id: Optional[str]) -> NoReturn: if pod["code"] == 409: raise AlreadyExistsException(pod.get("reason", "job already exists")) elif pod["code"] == 404: @@ -97,7 +84,7 @@ class Volume(metaclass=abc.ABCMeta): def create_mount(self, container_volume: ContainerVolume) -> "VolumeMount": raise NotImplementedError("Cannot create mount for abstract Volume type.") - def to_primitive(self) -> Dict[str, Any]: + def to_primitive(self) -> dict[str, Any]: raise NotImplementedError @@ -117,7 +104,7 @@ def create_mount(self, container_volume: ContainerVolume) -> "VolumeMount": @dataclass(frozen=True) class HostVolume(PathVolume): - def to_primitive(self) -> Dict[str, Any]: + def to_primitive(self) -> dict[str, Any]: return { "name": self.name, "hostPath": {"path": str(self.path), "type": "Directory"}, @@ -126,7 +113,7 @@ def to_primitive(self) -> Dict[str, Any]: @dataclass(frozen=True) class SharedMemoryVolume(Volume): - def to_primitive(self) -> Dict[str, Any]: + def to_primitive(self) -> dict[str, Any]: return {"name": self.name, "emptyDir": {"medium": "Memory"}} def 
create_mount(self, container_volume: ContainerVolume) -> "VolumeMount": @@ -142,7 +129,7 @@ def create_mount(self, container_volume: ContainerVolume) -> "VolumeMount": class NfsVolume(PathVolume): server: str - def to_primitive(self) -> Dict[str, Any]: + def to_primitive(self) -> dict[str, Any]: return { "name": self.name, "nfs": {"server": self.server, "path": str(self.path)}, @@ -153,7 +140,7 @@ def to_primitive(self) -> Dict[str, Any]: class PVCVolume(PathVolume): claim_name: str - def to_primitive(self) -> Dict[str, Any]: + def to_primitive(self) -> dict[str, Any]: return { "name": self.name, "persistentVolumeClaim": {"claimName": self.claim_name}, @@ -169,7 +156,7 @@ class SecretEnvVar: def create(cls, name: str, secret: Secret) -> "SecretEnvVar": return cls(name=name, secret=secret) - def to_primitive(self) -> Dict[str, Any]: + def to_primitive(self) -> dict[str, Any]: return { "name": self.name, "valueFrom": { @@ -188,7 +175,7 @@ class VolumeMount: sub_path: PurePath = PurePath("") read_only: bool = False - def to_primitive(self) -> Dict[str, Any]: + def to_primitive(self) -> dict[str, Any]: sub_path = str(self.sub_path) raw = { "name": self.volume.name, @@ -212,7 +199,7 @@ def create_secret_mount(self, sec_volume: SecretContainerVolume) -> "VolumeMount read_only=True, ) - def to_primitive(self) -> Dict[str, Any]: + def to_primitive(self) -> dict[str, Any]: return { "name": self.name, "secret": {"secretName": self.k8s_secret_name, "defaultMode": 0o400}, @@ -230,7 +217,7 @@ def create_disk_mount(self, disk_volume: DiskContainerVolume) -> "VolumeMount": read_only=disk_volume.read_only, ) - def to_primitive(self) -> Dict[str, Any]: + def to_primitive(self) -> dict[str, Any]: return { "name": self.name, "persistentVolumeClaim": {"claimName": self.claim_name}, @@ -271,8 +258,8 @@ def memory_request_mib(self) -> str: def tpu_key(self) -> str: return self.tpu_key_template.format(version=self.tpu_version) - def to_primitive(self) -> Dict[str, Any]: - payload: Dict[str, Any] = { + def to_primitive(self) -> dict[str, Any]: + payload: dict[str, Any] = { "requests": {"cpu": self.cpu_mcores, "memory": self.memory_mib}, "limits": {"cpu": self.cpu_mcores, "memory": self.memory_mib}, } @@ -287,7 +274,7 @@ def to_primitive(self) -> Dict[str, Any]: return payload @classmethod - def _parse_tpu_resource(cls, tpu: ContainerTPUResource) -> Tuple[str, int]: + def _parse_tpu_resource(cls, tpu: ContainerTPUResource) -> tuple[str, int]: try: tpu_version, tpu_cores = tpu.type.rsplit("-", 1) return tpu_version, int(tpu_cores) @@ -296,7 +283,7 @@ def _parse_tpu_resource(cls, tpu: ContainerTPUResource) -> Tuple[str, int]: @classmethod def from_container_resources(cls, resources: ContainerResources) -> "Resources": - kwargs: Dict[str, Any] = {} + kwargs: dict[str, Any] = {} if resources.tpu: kwargs["tpu_version"], kwargs["tpu_cores"] = cls._parse_tpu_resource( resources.tpu @@ -315,24 +302,24 @@ class Service: name: str target_port: Optional[int] uid: Optional[str] = None - selector: Dict[str, str] = field(default_factory=dict) + selector: dict[str, str] = field(default_factory=dict) port: int = 80 service_type: ServiceType = ServiceType.CLUSTER_IP cluster_ip: Optional[str] = None - labels: Dict[str, str] = field(default_factory=dict) + labels: dict[str, str] = field(default_factory=dict) def _add_port_map( self, port: Optional[int], target_port: Optional[int], port_name: str, - ports: List[Dict[str, Any]], + ports: list[dict[str, Any]], ) -> None: if target_port: ports.append({"port": port, "targetPort": 
target_port, "name": port_name}) - def to_primitive(self) -> Dict[str, Any]: - service_descriptor: Dict[str, Any] = { + def to_primitive(self) -> dict[str, Any]: + service_descriptor: dict[str, Any] = { "metadata": {"name": self.name}, "spec": { "type": self.service_type.value, @@ -376,15 +363,15 @@ def make_named(self, name: str) -> "Service": @classmethod def _find_port_by_name( - cls, name: str, port_mappings: List[Dict[str, Any]] - ) -> Dict[str, Any]: + cls, name: str, port_mappings: list[dict[str, Any]] + ) -> dict[str, Any]: for port_mapping in port_mappings: if port_mapping.get("name", None) == name: return port_mapping return {} @classmethod - def from_primitive(cls, payload: Dict[str, Any]) -> "Service": + def from_primitive(cls, payload: dict[str, Any]) -> "Service": http_payload = cls._find_port_by_name("http", payload["spec"]["ports"]) service_type = payload["spec"].get("type", Service.service_type.value) return cls( @@ -406,7 +393,7 @@ class IngressRule: service_port: Optional[int] = None @classmethod - def from_primitive(cls, payload: Dict[str, Any]) -> "IngressRule": + def from_primitive(cls, payload: dict[str, Any]) -> "IngressRule": http_paths = payload.get("http", {}).get("paths", []) http_path = http_paths[0] if http_paths else {} backend = http_path.get("backend", {}) @@ -418,8 +405,8 @@ def from_primitive(cls, payload: Dict[str, Any]) -> "IngressRule": service_port=service_port, ) - def to_primitive(self) -> Dict[str, Any]: - payload: Dict[str, Any] = {"host": self.host} + def to_primitive(self) -> dict[str, Any]: + payload: dict[str, Any] = {"host": self.host} if self.service_name: payload["http"] = { "paths": [ @@ -441,12 +428,12 @@ def from_service(cls, host: str, service: Service) -> "IngressRule": @dataclass(frozen=True) class Ingress: name: str - rules: List[IngressRule] = field(default_factory=list) - annotations: Dict[str, str] = field(default_factory=dict) - labels: Dict[str, str] = field(default_factory=dict) + rules: list[IngressRule] = field(default_factory=list) + annotations: dict[str, str] = field(default_factory=dict) + labels: dict[str, str] = field(default_factory=dict) - def to_primitive(self) -> Dict[str, Any]: - rules: List[Any] = [rule.to_primitive() for rule in self.rules] or [None] + def to_primitive(self) -> dict[str, Any]: + rules: list[Any] = [rule.to_primitive() for rule in self.rules] or [None] metadata = {"name": self.name, "annotations": self.annotations} if self.labels: metadata["labels"] = self.labels.copy() @@ -454,7 +441,7 @@ def to_primitive(self) -> Dict[str, Any]: return primitive @classmethod - def from_primitive(cls, payload: Dict[str, Any]) -> "Ingress": + def from_primitive(cls, payload: dict[str, Any]) -> "Ingress": # TODO (A Danshyn 06/13/18): should be refactored along with PodStatus kind = payload["kind"] if kind == "Ingress": @@ -515,7 +502,7 @@ def _build_json(self) -> str: ).encode("utf-8") ).decode("ascii") - def to_primitive(self) -> Dict[str, Any]: + def to_primitive(self) -> dict[str, Any]: return { "apiVersion": "v1", "kind": "Secret", @@ -529,11 +516,11 @@ def to_primitive(self) -> Dict[str, Any]: class SecretRef: name: str - def to_primitive(self) -> Dict[str, str]: + def to_primitive(self) -> dict[str, str]: return {"name": self.name} @classmethod - def from_primitive(cls, payload: Dict[str, str]) -> "SecretRef": + def from_primitive(cls, payload: dict[str, str]) -> "SecretRef": return cls(**payload) @@ -548,7 +535,7 @@ class Toleration: value: str = "" effect: str = "" - def to_primitive(self) -> Dict[str, 
Any]: + def to_primitive(self) -> dict[str, Any]: return { "key": self.key, "operator": self.operator, @@ -574,7 +561,7 @@ def requires_no_values(self) -> bool: class NodeSelectorRequirement: key: str operator: NodeSelectorOperator - values: List[str] = field(default_factory=list) + values: list[str] = field(default_factory=list) def __post_init__(self) -> None: if not self.key: @@ -594,8 +581,8 @@ def create_exists(cls, key: str) -> "NodeSelectorRequirement": def create_does_not_exist(cls, key: str) -> "NodeSelectorRequirement": return cls(key=key, operator=NodeSelectorOperator.DOES_NOT_EXIST) - def to_primitive(self) -> Dict[str, Any]: - payload: Dict[str, Any] = {"key": self.key, "operator": self.operator.value} + def to_primitive(self) -> dict[str, Any]: + payload: dict[str, Any] = {"key": self.key, "operator": self.operator.value} if self.values: payload["values"] = self.values.copy() return payload @@ -603,13 +590,13 @@ def to_primitive(self) -> Dict[str, Any]: @dataclass(frozen=True) class NodeSelectorTerm: - match_expressions: List[NodeSelectorRequirement] + match_expressions: list[NodeSelectorRequirement] def __post_init__(self) -> None: if not self.match_expressions: raise ValueError("no expressions") - def to_primitive(self) -> Dict[str, Any]: + def to_primitive(self) -> dict[str, Any]: return { "matchExpressions": [expr.to_primitive() for expr in self.match_expressions] } @@ -620,21 +607,21 @@ class NodePreferredSchedulingTerm: preference: NodeSelectorTerm weight: int = 100 - def to_primitive(self) -> Dict[str, Any]: + def to_primitive(self) -> dict[str, Any]: return {"preference": self.preference.to_primitive(), "weight": self.weight} @dataclass(frozen=True) class NodeAffinity: - required: List[NodeSelectorTerm] = field(default_factory=list) - preferred: List[NodePreferredSchedulingTerm] = field(default_factory=list) + required: list[NodeSelectorTerm] = field(default_factory=list) + preferred: list[NodePreferredSchedulingTerm] = field(default_factory=list) def __post_init__(self) -> None: if not self.required and not self.preferred: raise ValueError("no terms") - def to_primitive(self) -> Dict[str, Any]: - payload: Dict[str, Any] = {} + def to_primitive(self) -> dict[str, Any]: + payload: dict[str, Any] = {} if self.required: payload["requiredDuringSchedulingIgnoredDuringExecution"] = { "nodeSelectorTerms": [term.to_primitive() for term in self.required] @@ -663,20 +650,20 @@ def __repr__(self) -> str: class PodDescriptor: name: str image: str - command: List[str] = field(default_factory=list) - args: List[str] = field(default_factory=list) + command: list[str] = field(default_factory=list) + args: list[str] = field(default_factory=list) working_dir: Optional[str] = None - env: Dict[str, str] = field(default_factory=dict) + env: dict[str, str] = field(default_factory=dict) # TODO (artem): create base type `EnvVar` and merge `env` and `secret_env` - secret_env_list: List[SecretEnvVar] = field(default_factory=list) - volume_mounts: List[VolumeMount] = field(default_factory=list) - volumes: List[Volume] = field(default_factory=list) + secret_env_list: list[SecretEnvVar] = field(default_factory=list) + volume_mounts: list[VolumeMount] = field(default_factory=list) + volumes: list[Volume] = field(default_factory=list) resources: Optional[Resources] = None - node_selector: Dict[str, str] = field(default_factory=dict) - tolerations: List[Toleration] = field(default_factory=list) + node_selector: dict[str, str] = field(default_factory=dict) + tolerations: list[Toleration] = 
field(default_factory=list) node_affinity: Optional[NodeAffinity] = None - labels: Dict[str, str] = field(default_factory=dict) - annotations: Dict[str, str] = field(default_factory=dict) + labels: dict[str, str] = field(default_factory=dict) + annotations: dict[str, str] = field(default_factory=dict) port: Optional[int] = None health_check_path: str = "/" @@ -684,7 +671,7 @@ class PodDescriptor: status: Optional["PodStatus"] = None - image_pull_secrets: List[SecretRef] = field(default_factory=list) + image_pull_secrets: list[SecretRef] = field(default_factory=list) # TODO (A Danshyn 12/09/2018): expose readiness probe properly readiness_probe: bool = False @@ -705,7 +692,7 @@ def _process_storage_volumes( cls, container: Container, storage_volume_factory: Optional[Callable[[ContainerVolume], Volume]] = None, - ) -> Tuple[List[Volume], List[VolumeMount]]: + ) -> tuple[list[Volume], list[VolumeMount]]: if not storage_volume_factory: return [], [] @@ -725,7 +712,7 @@ def _process_secret_volumes( cls, container: Container, secret_volume_factory: Optional[Callable[[str], SecretVolume]] = None, - ) -> Tuple[List[SecretVolume], List[VolumeMount]]: + ) -> tuple[list[SecretVolume], list[VolumeMount]]: path_to_secret_volumes = container.get_path_to_secret_volumes() if not secret_volume_factory: return [], [] @@ -744,12 +731,12 @@ def _process_secret_volumes( @classmethod def _process_disk_volumes( - cls, disk_volumes: List[DiskContainerVolume] - ) -> Tuple[List[PVCDiskVolume], List[VolumeMount]]: + cls, disk_volumes: list[DiskContainerVolume] + ) -> tuple[list[PVCDiskVolume], list[VolumeMount]]: pod_volumes = [] volume_mounts = [] - pvc_volumes: Dict[str, PVCDiskVolume] = dict() + pvc_volumes: dict[str, PVCDiskVolume] = dict() for index, disk_volume in enumerate(disk_volumes, 1): pvc_volume = pvc_volumes.get(disk_volume.disk.disk_id) if pvc_volume is None: @@ -767,14 +754,14 @@ def from_job_request( job_request: JobRequest, storage_volume_factory: Optional[Callable[[ContainerVolume], Volume]] = None, secret_volume_factory: Optional[Callable[[str], SecretVolume]] = None, - image_pull_secret_names: Optional[List[str]] = None, - node_selector: Optional[Dict[str, str]] = None, - tolerations: Optional[List[Toleration]] = None, + image_pull_secret_names: Optional[list[str]] = None, + node_selector: Optional[dict[str, str]] = None, + tolerations: Optional[list[Toleration]] = None, node_affinity: Optional[NodeAffinity] = None, - labels: Optional[Dict[str, str]] = None, + labels: Optional[dict[str, str]] = None, priority_class_name: Optional[str] = None, restart_policy: PodRestartPolicy = PodRestartPolicy.NEVER, - meta_env: Optional[Dict[str, str]] = None, + meta_env: Optional[dict[str, str]] = None, privileged: bool = False, ) -> "PodDescriptor": container = job_request.container @@ -817,7 +804,7 @@ def from_job_request( else: image_pull_secrets = [] - annotations: Dict[str, str] = {} + annotations: dict[str, str] = {} if container.resources.tpu: annotations[ cls.tpu_version_annotation_key @@ -853,15 +840,15 @@ def from_job_request( ) @property - def env_list(self) -> List[Dict[str, str]]: + def env_list(self) -> list[dict[str, str]]: return [dict(name=name, value=value) for name, value in self.env.items()] - def to_primitive(self) -> Dict[str, Any]: + def to_primitive(self) -> dict[str, Any]: volume_mounts = [mount.to_primitive() for mount in self.volume_mounts] volumes = [volume.to_primitive() for volume in self.volumes] env_list = self.env_list + [env.to_primitive() for env in self.secret_env_list] - 
container_payload: Dict[str, Any] = { + container_payload: dict[str, Any] = { "name": f"{self.name}", "image": f"{self.image}", "imagePullPolicy": "Always", @@ -900,7 +887,7 @@ def to_primitive(self) -> Dict[str, Any]: ) ) - payload: Dict[str, Any] = { + payload: dict[str, Any] = { "kind": "Pod", "apiVersion": "v1", "metadata": {"name": self.name}, @@ -931,13 +918,13 @@ def to_primitive(self) -> Dict[str, Any]: payload["spec"]["priorityClassName"] = self.priority_class_name return payload - def _to_primitive_ports(self) -> List[Dict[str, int]]: + def _to_primitive_ports(self) -> list[dict[str, int]]: ports = [] if self.port: ports.append({"containerPort": self.port}) return ports - def _to_primitive_readiness_probe(self) -> Dict[str, Any]: + def _to_primitive_readiness_probe(self) -> dict[str, Any]: if not self.readiness_probe: return {} @@ -951,7 +938,7 @@ def _to_primitive_readiness_probe(self) -> Dict[str, Any]: return {} @classmethod - def _assert_resource_kind(cls, expected_kind: str, payload: Dict[str, Any]) -> None: + def _assert_resource_kind(cls, expected_kind: str, payload: dict[str, Any]) -> None: kind = payload["kind"] if kind == "Status": _raise_status_job_exception(payload, job_id="") @@ -959,7 +946,7 @@ def _assert_resource_kind(cls, expected_kind: str, payload: Dict[str, Any]) -> N raise ValueError(f"unknown kind: {kind}") @classmethod - def from_primitive(cls, payload: Dict[str, Any]) -> "PodDescriptor": + def from_primitive(cls, payload: dict[str, Any]) -> "PodDescriptor": cls._assert_resource_kind(expected_kind="Pod", payload=payload) metadata = payload["metadata"] @@ -1006,11 +993,11 @@ def from_primitive(cls, payload: Dict[str, Any]) -> "PodDescriptor": class ContainerStatus: - def __init__(self, payload: Optional[Dict[str, Any]] = None) -> None: + def __init__(self, payload: Optional[dict[str, Any]] = None) -> None: self._payload = payload or {} @property - def _state(self) -> Dict[str, Any]: + def _state(self) -> dict[str, Any]: return self._payload.get("state", {}) @property @@ -1076,7 +1063,7 @@ class PodConditionType(enum.Enum): class PodCondition: # https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-conditions - def __init__(self, payload: Dict[str, Any]) -> None: + def __init__(self, payload: dict[str, Any]) -> None: self._payload = payload @property @@ -1111,11 +1098,11 @@ def type(self) -> PodConditionType: class KubernetesEvent: - def __init__(self, payload: Dict[str, Any]) -> None: + def __init__(self, payload: dict[str, Any]) -> None: self._payload = payload or {} @property - def involved_object(self) -> Dict[str, str]: + def involved_object(self) -> dict[str, str]: return self._payload["involvedObject"] @property @@ -1140,7 +1127,7 @@ def count(self) -> int: class PodStatus: - def __init__(self, payload: Dict[str, Any]) -> None: + def __init__(self, payload: dict[str, Any]) -> None: self._payload = payload self._container_status = self._init_container_status() @@ -1201,11 +1188,11 @@ def is_node_lost(self) -> bool: return self.reason == "NodeLost" @property - def conditions(self) -> List[PodCondition]: + def conditions(self) -> list[PodCondition]: return [PodCondition(val) for val in self._payload.get("conditions", [])] @classmethod - def from_primitive(cls, payload: Dict[str, Any]) -> "PodStatus": + def from_primitive(cls, payload: dict[str, Any]) -> "PodStatus": return cls(payload) @@ -1225,7 +1212,7 @@ class PodExec: def __init__(self, ws: aiohttp.ClientWebSocketResponse) -> None: self._ws = ws - self._channels: 
DefaultDict[ExecChannel, Stream] = DefaultDict(Stream) + self._channels: defaultdict[ExecChannel, Stream] = defaultdict(Stream) loop = asyncio.get_event_loop() self._reader_task = loop.create_task(self._read_data()) self._exit_code = loop.create_future() @@ -1288,7 +1275,7 @@ async def __aenter__(self) -> "PodExec": async def __aexit__( self, - exc_type: Optional[Type[BaseException]], + exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: @@ -1314,7 +1301,7 @@ class NodeTaint: value: str effect: str = "NoSchedule" - def to_primitive(self) -> Dict[str, Any]: + def to_primitive(self) -> dict[str, Any]: return {"key": self.key, "value": self.value, "effect": self.effect} @@ -1334,7 +1321,7 @@ def __init__( conn_timeout_s: int = 300, read_timeout_s: int = 100, conn_pool_size: int = 100, - trace_configs: Optional[List[aiohttp.TraceConfig]] = None, + trace_configs: Optional[list[aiohttp.TraceConfig]] = None, ) -> None: self._base_url = base_url self._namespace = namespace @@ -1504,7 +1491,7 @@ def _generate_pvc_url( all_pvcs_url = self._generate_all_pvcs_url(namespace_name) return f"{all_pvcs_url}/{pvc_name}" - async def _request(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: + async def _request(self, *args: Any, **kwargs: Any) -> dict[str, Any]: assert self._client async with self._client.request(*args, **kwargs) as response: # TODO (A Danshyn 05/21/18): check status code etc @@ -1547,15 +1534,15 @@ async def _delete_resource_url(self, url: str, uid: Optional[str] = None) -> Non async def get_endpoint( self, name: str, namespace: Optional[str] = None - ) -> Dict[str, Any]: + ) -> dict[str, Any]: url = self._generate_endpoint_url(name, namespace or self._namespace) return await self._request(method="GET", url=url) async def create_node( self, name: str, - capacity: Dict[str, Any], - labels: Optional[Dict[str, str]] = None, + capacity: dict[str, Any], + labels: Optional[dict[str, str]] = None, taints: Optional[Sequence[NodeTaint]] = None, ) -> None: taints = taints or [] @@ -1586,8 +1573,8 @@ async def create_pod(self, descriptor: PodDescriptor) -> PodDescriptor: return pod async def set_raw_pod_status( - self, name: str, payload: Dict[str, Any] - ) -> Dict[str, Any]: + self, name: str, payload: dict[str, Any] + ) -> dict[str, Any]: url = self._generate_pod_url(name) + "/status" return await self._request(method="PUT", url=url, json=payload) @@ -1596,11 +1583,11 @@ async def get_pod(self, pod_name: str) -> PodDescriptor: payload = await self._request(method="GET", url=url) return PodDescriptor.from_primitive(payload) - async def get_raw_pod(self, name: str) -> Dict[str, Any]: + async def get_raw_pod(self, name: str) -> dict[str, Any]: url = self._generate_pod_url(name) return await self._request(method="GET", url=url) - async def get_raw_pods(self) -> Sequence[Dict[str, Any]]: + async def get_raw_pods(self) -> Sequence[dict[str, Any]]: payload = await self._request(method="GET", url=self._pods_url) return payload["items"] @@ -1626,9 +1613,9 @@ async def delete_pod(self, pod_name: str, force: bool = False) -> PodStatus: async def create_ingress( self, name: str, - rules: Optional[List[IngressRule]] = None, - annotations: Optional[Dict[str, str]] = None, - labels: Optional[Dict[str, str]] = None, + rules: Optional[list[IngressRule]] = None, + annotations: Optional[dict[str, str]] = None, + labels: Optional[dict[str, str]] = None, ) -> Ingress: rules = rules or [] annotations = annotations or {} @@ -1647,9 +1634,9 @@ async def 
get_ingress(self, name: str) -> Ingress: return Ingress.from_primitive(payload) async def delete_all_ingresses( - self, *, labels: Optional[Dict[str, str]] = None + self, *, labels: Optional[dict[str, str]] = None ) -> None: - params: Dict[str, str] = {} + params: dict[str, str] = {} if labels: params["labelSelector"] = ",".join( "=".join(item) for item in labels.items() @@ -1664,7 +1651,7 @@ async def delete_ingress(self, name: str) -> None: payload = await self._request(method="DELETE", url=url) self._check_status_payload(payload) - def _check_status_payload(self, payload: Dict[str, Any]) -> None: + def _check_status_payload(self, payload: dict[str, Any]) -> None: if payload["kind"] == "Status": if payload["status"] == "Failure": if payload.get("reason") == "AlreadyExists": @@ -1714,7 +1701,7 @@ async def get_service(self, name: str) -> Service: self._check_status_payload(payload) return Service.from_primitive(payload) - async def list_services(self, labels: Dict[str, str]) -> List[Service]: + async def list_services(self, labels: dict[str, str]) -> list[Service]: url = self._services_url labelSelector = ",".join(f"{label}={value}" for label, value in labels.items()) payload = await self._request( @@ -1751,7 +1738,7 @@ async def update_docker_secret( async def get_raw_secret( self, secret_name: str, namespace_name: Optional[str] = None - ) -> Dict[str, Any]: + ) -> dict[str, Any]: url = self._generate_secret_url(secret_name, namespace_name) payload = await self._request(method="GET", url=url) self._check_status_payload(payload) @@ -1765,7 +1752,7 @@ async def delete_secret( async def get_raw_pvc( self, pvc_name: str, namespace_name: Optional[str] = None - ) -> Dict[str, Any]: + ) -> dict[str, Any]: url = self._generate_pvc_url(pvc_name, namespace_name) payload = await self._request(method="GET", url=url) self._check_status_payload(payload) @@ -1773,7 +1760,7 @@ async def get_raw_pvc( async def get_pod_events( self, pod_id: str, namespace: str - ) -> List[KubernetesEvent]: + ) -> list[KubernetesEvent]: params = { "fieldSelector": ( "involvedObject.kind=Pod" @@ -1843,12 +1830,12 @@ async def wait_pod_is_terminated( async def create_default_network_policy( self, name: str, - pod_labels: Dict[str, str], + pod_labels: dict[str, str], namespace_name: Optional[str] = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: assert pod_labels # https://tools.ietf.org/html/rfc1918#section-3 - rules: List[Dict[str, Any]] = [ + rules: list[dict[str, Any]] = [ # allowing pods to connect to public networks only { "to": [ @@ -1885,11 +1872,11 @@ async def create_egress_network_policy( self, name: str, *, - pod_labels: Dict[str, str], - rules: List[Dict[str, Any]], + pod_labels: dict[str, str], + rules: list[dict[str, Any]], namespace_name: Optional[str] = None, - labels: Optional[Dict[str, str]] = None, - ) -> Dict[str, Any]: + labels: Optional[dict[str, str]] = None, + ) -> dict[str, Any]: assert pod_labels assert rules labels = labels or {} @@ -1911,7 +1898,7 @@ async def create_egress_network_policy( async def get_network_policy( self, name: str, namespace_name: Optional[str] = None - ) -> Dict[str, Any]: + ) -> dict[str, Any]: url = self._generate_network_policy_url(name, namespace_name) payload = await self._request(method="GET", url=url) self._check_status_payload(payload) diff --git a/platform_api/orchestrator/kube_orchestrator.py b/platform_api/orchestrator/kube_orchestrator.py index 5f21d851b..049295e21 100644 --- a/platform_api/orchestrator/kube_orchestrator.py +++ 
b/platform_api/orchestrator/kube_orchestrator.py @@ -3,9 +3,10 @@ import operator import secrets from collections import defaultdict +from collections.abc import Iterable, Sequence from dataclasses import replace from datetime import datetime, timezone -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Union import aiohttp @@ -122,7 +123,7 @@ def __init__( registry_config: RegistryConfig, orchestrator_config: OrchestratorConfig, kube_config: KubeConfig, - trace_configs: Optional[List[aiohttp.TraceConfig]] = None, + trace_configs: Optional[list[aiohttp.TraceConfig]] = None, ) -> None: self._loop = asyncio.get_event_loop() self._storage_configs = storage_configs @@ -270,7 +271,7 @@ async def _create_pod_network_policy(self, job: Job) -> None: name = self._get_job_pod_name(job) pod_labels = self._get_job_labels(job) - rules: List[Dict[str, Any]] = [ + rules: list[dict[str, Any]] = [ # allowing the pod to connect to TPU nodes within internal network {"to": [{"ipBlock": {"cidr": tpu_ipv4_cidr_block}}]} ] @@ -361,8 +362,8 @@ def _get_cheapest_pool_types(self, job: Job) -> Sequence[ResourcePoolType]: # because during job lifetime node pools, presets can change and # node affinity assigned to the job won't be valid anymore. container_resources = job.request.container.resources - TKey = Tuple[int, float, int] - pool_types: Dict[TKey, List[ResourcePoolType]] = defaultdict(list) + TKey = tuple[int, float, int] + pool_types: dict[TKey, list[ResourcePoolType]] = defaultdict(list) for pool_type in self._orchestrator_config.resource_pool_types: # Schedule jobs only on preemptible nodes if such node specified @@ -418,23 +419,23 @@ def _update_pod_container_resources( new_resources = replace(pod.resources, memory_request=1024) return replace(pod, resources=new_resources) - def _get_user_pod_labels(self, job: Job) -> Dict[str, str]: + def _get_user_pod_labels(self, job: Job) -> dict[str, str]: return {"platform.neuromation.io/user": job.owner.replace("/", "--")} - def _get_job_labels(self, job: Job) -> Dict[str, str]: + def _get_job_labels(self, job: Job) -> dict[str, str]: return {"platform.neuromation.io/job": job.id} - def _get_preset_labels(self, job: Job) -> Dict[str, str]: + def _get_preset_labels(self, job: Job) -> dict[str, str]: if job.preset_name: return {"platform.neuromation.io/preset": job.preset_name} return {} - def _get_gpu_labels(self, job: Job) -> Dict[str, str]: + def _get_gpu_labels(self, job: Job) -> dict[str, str]: if not job.has_gpu or not job.gpu_model_id: return {} return {"platform.neuromation.io/gpu-model": job.gpu_model_id} - def _get_pod_labels(self, job: Job) -> Dict[str, str]: + def _get_pod_labels(self, job: Job) -> dict[str, str]: labels = self._get_job_labels(job) labels.update(self._get_user_pod_labels(job)) labels.update(self._get_gpu_labels(job)) @@ -444,7 +445,7 @@ def _get_pod_labels(self, job: Job) -> Dict[str, str]: def _get_pod_restart_policy(self, job: Job) -> PodRestartPolicy: return self._restart_policy_map[job.restart_policy] - async def get_missing_disks(self, disks: List[Disk]) -> List[Disk]: + async def get_missing_disks(self, disks: list[Disk]) -> list[Disk]: assert disks, "no disks" missing = [] for disk in disks: @@ -467,8 +468,8 @@ async def get_missing_disks(self, disks: List[Disk]) -> List[Disk]: return sorted(missing, key=lambda disk: disk.disk_id) async def get_missing_secrets( - self, secret_path: str, secret_names: List[str] - ) -> List[str]: + self, secret_path: str, secret_names: 
list[str] + ) -> list[str]: assert secret_names, "no sec names" user_secret_name = self._get_k8s_secret_name(secret_path) try: @@ -524,7 +525,7 @@ async def start_job( def _get_pod_tolerations( self, job: Job, tolerate_unreachable_node: bool = False - ) -> List[Toleration]: + ) -> list[Toleration]: tolerations = [ Toleration( key=self._kube_config.jobs_pod_job_toleration_key, @@ -563,8 +564,8 @@ def _get_pod_node_affinity( # `NodeSelectorTerm`s is satisfied. # `NodeSelectorTerm` is satisfied only if its `match_expressions` are # satisfied. - required_terms: List[NodeSelectorTerm] = [] - preferred_terms: List[NodePreferredSchedulingTerm] = [] + required_terms: list[NodeSelectorTerm] = [] + preferred_terms: list[NodePreferredSchedulingTerm] = [] if self._kube_config.node_label_node_pool: for pool_type in pool_types: @@ -733,7 +734,7 @@ async def _create_service( service = service.make_named(name) return await self._client.create_service(service) - async def _get_services(self, job: Job) -> List[Service]: + async def _get_services(self, job: Job) -> list[Service]: return await self._client.list_services(self._get_job_labels(job)) async def _delete_service( @@ -766,8 +767,8 @@ async def delete_job(self, job: Job) -> JobStatus: def _get_job_ingress_name(self, job: Job) -> str: return job.id - def _get_ingress_annotations(self, job: Job) -> Dict[str, str]: - annotations: Dict[str, str] = {} + def _get_ingress_annotations(self, job: Job) -> dict[str, str]: + annotations: dict[str, str] = {} if self._kube_config.jobs_ingress_class == "traefik": annotations = { "kubernetes.io/ingress.class": "traefik", @@ -793,13 +794,13 @@ def _get_ingress_annotations(self, job: Job) -> Dict[str, str]: def _get_job_name_ingress_labels( self, job: Job, service: Service - ) -> Dict[str, str]: + ) -> dict[str, str]: labels = self._get_user_pod_labels(job) if job.name: labels["platform.neuromation.io/job-name"] = job.name return labels - def _get_ingress_labels(self, job: Job, service: Service) -> Dict[str, str]: + def _get_ingress_labels(self, job: Job, service: Service) -> dict[str, str]: return {**service.labels, **self._get_job_name_ingress_labels(job, service)} async def _delete_ingresses_by_job_name(self, job: Job, service: Service) -> None: diff --git a/platform_api/orchestrator/poller_service.py b/platform_api/orchestrator/poller_service.py index 112d37247..0b0e9bf01 100644 --- a/platform_api/orchestrator/poller_service.py +++ b/platform_api/orchestrator/poller_service.py @@ -1,11 +1,12 @@ import abc import logging from collections import defaultdict +from collections.abc import AsyncIterator, Callable from contextlib import asynccontextmanager from dataclasses import dataclass from datetime import datetime, timedelta, timezone from functools import partial -from typing import AsyncIterator, Callable, Dict, List, Optional, Tuple, Union +from typing import Optional, Union from aiohttp import ClientResponseError from neuro_admin_client import AdminClient @@ -39,8 +40,8 @@ @dataclass(frozen=True) class SchedulingResult: - jobs_to_update: List[JobRecord] - jobs_to_suspend: List[JobRecord] + jobs_to_update: list[JobRecord] + jobs_to_suspend: list[JobRecord] current_datetime_factory = partial(datetime.now, timezone.utc) @@ -93,11 +94,11 @@ async def _get_org_running_jobs_quota( async def _enforce_running_job_quota( self, raw_result: SchedulingResult ) -> SchedulingResult: - jobs_to_update: List[JobRecord] = [] + jobs_to_update: list[JobRecord] = [] # Grouping by (username, cluster_name, org_name): - grouped_jobs: 
Dict[ - Tuple[str, str, Optional[str]], List[JobRecord] + grouped_jobs: dict[ + tuple[str, str, Optional[str]], list[JobRecord] ] = defaultdict(list) for record in raw_result.jobs_to_update: grouped_jobs[(record.owner, record.cluster_name, record.org_name)].append( @@ -105,8 +106,8 @@ async def _enforce_running_job_quota( ) def _filter_our_for_quota( - quota: Optional[int], jobs: List[JobRecord] - ) -> List[JobRecord]: + quota: Optional[int], jobs: list[JobRecord] + ) -> list[JobRecord]: if quota is not None: materialized_jobs = [job for job in jobs if job.materialized] not_materialized = [job for job in jobs if not job.materialized] @@ -127,8 +128,8 @@ def _filter_our_for_quota( jobs_to_update.extend(_filter_our_for_quota(quota, jobs)) # Grouping by (cluster_name, org_name): - grouped_by_org_jobs: Dict[ - Tuple[str, Optional[str]], List[JobRecord] + grouped_by_org_jobs: dict[ + tuple[str, Optional[str]], list[JobRecord] ] = defaultdict(list) for record in jobs_to_update: grouped_by_org_jobs[(record.cluster_name, record.org_name)].append(record) @@ -144,9 +145,9 @@ def _filter_our_for_quota( jobs_to_suspend=raw_result.jobs_to_suspend, ) - async def schedule(self, unfinished: List[JobRecord]) -> SchedulingResult: - jobs_to_update: List[JobRecord] = [] - jobs_to_suspend: List[JobRecord] = [] + async def schedule(self, unfinished: list[JobRecord]) -> SchedulingResult: + jobs_to_update: list[JobRecord] = [] + jobs_to_suspend: list[JobRecord] = [] now = self._current_datetime_factory() # Always start/update not scheduled jobs @@ -203,10 +204,10 @@ async def schedule(self, unfinished: List[JobRecord]) -> SchedulingResult: class JobsPollerApi(abc.ABC): - async def get_unfinished_jobs(self) -> List[JobRecord]: + async def get_unfinished_jobs(self) -> list[JobRecord]: raise NotImplementedError - async def get_jobs_for_deletion(self, *, delay: timedelta) -> List[JobRecord]: + async def get_jobs_for_deletion(self, *, delay: timedelta) -> list[JobRecord]: raise NotImplementedError async def push_status(self, job_id: str, status: JobStatusItem) -> None: diff --git a/platform_api/poller_main.py b/platform_api/poller_main.py index 3608ceee6..4cfca3427 100644 --- a/platform_api/poller_main.py +++ b/platform_api/poller_main.py @@ -1,7 +1,8 @@ import asyncio import logging +from collections.abc import AsyncIterator, Callable from contextlib import AsyncExitStack -from typing import AsyncIterator, Callable, List, Optional +from typing import Optional import aiohttp.web from aiohttp.web_urldispatcher import AbstractRoute @@ -37,7 +38,7 @@ class Handler: def __init__(self, app: aiohttp.web.Application): self._app = app - def register(self, app: aiohttp.web.Application) -> List[AbstractRoute]: + def register(self, app: aiohttp.web.Application) -> list[AbstractRoute]: return app.add_routes((aiohttp.web.get("/ping", self.handle_ping),)) @notrace @@ -60,7 +61,7 @@ def _create_cluster(cluster_config: ClusterConfig) -> Cluster: return _create_cluster -def make_tracing_trace_configs(config: PollerConfig) -> List[aiohttp.TraceConfig]: +def make_tracing_trace_configs(config: PollerConfig) -> list[aiohttp.TraceConfig]: trace_configs = [] if config.zipkin: diff --git a/platform_api/resource.py b/platform_api/resource.py index 720261d44..2d2325845 100644 --- a/platform_api/resource.py +++ b/platform_api/resource.py @@ -1,8 +1,9 @@ import uuid +from collections.abc import Sequence from dataclasses import dataclass, field from decimal import Decimal from enum import Enum -from typing import Optional, Sequence +from typing 
import Optional @dataclass(frozen=True) diff --git a/platform_api/utils/asyncio.py b/platform_api/utils/asyncio.py index f89412243..3ad5c4fcc 100644 --- a/platform_api/utils/asyncio.py +++ b/platform_api/utils/asyncio.py @@ -2,17 +2,10 @@ import functools import logging import sys +from collections.abc import Awaitable, Callable, Iterable +from contextlib import AbstractAsyncContextManager from types import TracebackType -from typing import ( - Any, - AsyncContextManager, - Awaitable, - Callable, - Iterable, - Optional, - Type, - TypeVar, -) +from typing import Any, Optional, TypeVar async def run_and_log_exceptions(coros: Iterable[Awaitable[Any]]) -> None: @@ -31,7 +24,7 @@ async def run_and_log_exceptions(coros: Iterable[Awaitable[Any]]) -> None: from contextlib import aclosing else: - class aclosing(AsyncContextManager[T_co]): + class aclosing(AbstractAsyncContextManager[T_co]): def __init__(self, thing: T_co): self.thing = thing @@ -40,7 +33,7 @@ async def __aenter__(self) -> T_co: async def __aexit__( self, - exc_type: Optional[Type[BaseException]], + exc_type: Optional[type[BaseException]], exc: Optional[BaseException], tb: Optional[TracebackType], ) -> None: @@ -49,9 +42,9 @@ async def __aexit__( def asyncgeneratorcontextmanager( func: Callable[..., T_co] -) -> Callable[..., AsyncContextManager[T_co]]: +) -> Callable[..., AbstractAsyncContextManager[T_co]]: @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> AsyncContextManager[T_co]: + def wrapper(*args: Any, **kwargs: Any) -> AbstractAsyncContextManager[T_co]: return aclosing(func(*args, **kwargs)) return wrapper diff --git a/platform_api/utils/retry.py b/platform_api/utils/retry.py index 39283a53c..a96e1d82d 100644 --- a/platform_api/utils/retry.py +++ b/platform_api/utils/retry.py @@ -1,6 +1,7 @@ import asyncio import logging -from typing import Any, Callable, Iterator, Tuple, Type +from collections.abc import Callable, Iterator +from typing import Any log = logging.getLogger(__name__) @@ -10,7 +11,7 @@ class retries: def __init__( self, msg: str, - catch: Tuple[Type[Exception], ...], + catch: tuple[type[Exception], ...], attempts: int = 10, logger: Callable[[str], None] = log.info, ) -> None: @@ -34,7 +35,7 @@ async def __aenter__(self) -> None: pass async def __aexit__( - self, type: Type[BaseException], value: BaseException, tb: Any + self, type: type[BaseException], value: BaseException, tb: Any ) -> bool: if type is None: # Stop iteration diff --git a/platform_api/utils/stream.py b/platform_api/utils/stream.py index 5ba09f32c..792561d63 100644 --- a/platform_api/utils/stream.py +++ b/platform_api/utils/stream.py @@ -1,5 +1,6 @@ import asyncio -from typing import Deque, Optional +from collections import deque +from typing import Optional class Stream: @@ -11,7 +12,7 @@ class Stream: def __init__(self) -> None: self._loop = asyncio.get_event_loop() self._waiter: Optional[asyncio.Future[None]] = None - self._data: Deque[bytes] = Deque() + self._data: deque[bytes] = deque() self._closed = False @property diff --git a/platform_api/utils/update_notifier.py b/platform_api/utils/update_notifier.py index 4b3c2ea4a..11888d72e 100644 --- a/platform_api/utils/update_notifier.py +++ b/platform_api/utils/update_notifier.py @@ -1,8 +1,8 @@ import asyncio import logging from abc import ABC, abstractmethod -from contextlib import asynccontextmanager, suppress -from typing import Any, AsyncContextManager, Callable, List, Optional +from contextlib import AbstractAsyncContextManager, asynccontextmanager, suppress +from typing 
import Any, Callable, Optional import asyncpg from sqlalchemy.ext.asyncio import AsyncEngine @@ -30,13 +30,13 @@ async def notify(self) -> None: @abstractmethod def listen_to_updates( self, listener: Callback - ) -> AsyncContextManager[Subscription]: + ) -> AbstractAsyncContextManager[Subscription]: pass class InMemoryNotifier(Notifier): def __init__(self) -> None: - self._callbacks: List[Callback] = [] + self._callbacks: list[Callback] = [] async def notify(self) -> None: for callback in self._callbacks: @@ -124,7 +124,7 @@ async def notify(self) -> None: await self._inner_notifier.notify() class _Subscription(Subscription): - _inner_manager: Optional[AsyncContextManager[Subscription]] = None + _inner_manager: Optional[AbstractAsyncContextManager[Subscription]] = None _subscription: Optional[Subscription] = None _task: Optional["asyncio.Task[None]"] = None diff --git a/setup.cfg b/setup.cfg index b8595df07..5f3b88cdc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -10,7 +10,7 @@ classifiers = [options] zip_safe = False -python_requires = >=3.8 +python_requires = >=3.9 include_package_data = True packages = find: platforms = any @@ -40,7 +40,6 @@ console_scripts = dev = mypy==0.920 pre-commit==2.16.0 - types-setuptools==57.4.4 aiodocker==0.21.0 codecov==2.1.12 pytest==6.2.5 diff --git a/tests/conftest.py b/tests/conftest.py index 04d281f82..41c527b57 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ +from collections.abc import Iterator from contextlib import contextmanager -from typing import Iterator, Type from uuid import uuid1 import pytest @@ -18,7 +18,7 @@ @contextmanager -def not_raises(exc_cls: Type[Exception]) -> Iterator[None]: +def not_raises(exc_cls: type[Exception]) -> Iterator[None]: try: yield except exc_cls as exc: diff --git a/tests/integration/admin.py b/tests/integration/admin.py index 42c20f19a..0025cda6b 100644 --- a/tests/integration/admin.py +++ b/tests/integration/admin.py @@ -1,6 +1,7 @@ import asyncio +from collections.abc import AsyncGenerator, AsyncIterator from contextlib import asynccontextmanager -from typing import Any, AsyncGenerator, AsyncIterator, List +from typing import Any import aiodocker import aiohttp.web @@ -17,7 +18,7 @@ @pytest.fixture(scope="session") async def fake_config_app() -> AsyncIterator[URL]: app = aiohttp.web.Application() - clusters: List[Any] = [] + clusters: list[Any] = [] async def add_cluster(request: aiohttp.web.Request) -> aiohttp.web.Response: payload = await request.json() diff --git a/tests/integration/api.py b/tests/integration/api.py index c7b702d9e..23df143d8 100644 --- a/tests/integration/api.py +++ b/tests/integration/api.py @@ -1,7 +1,8 @@ import asyncio import json import time -from typing import Any, AsyncIterator, Callable, Dict, List, NamedTuple, Optional, Set +from collections.abc import AsyncIterator, Callable +from typing import Any, NamedTuple, Optional import aiohttp import aiohttp.web @@ -157,13 +158,13 @@ async def client() -> AsyncIterator[aiohttp.ClientSession]: class JobsClient: def __init__( - self, api_config: ApiConfig, client: ClientSession, headers: Dict[str, str] + self, api_config: ApiConfig, client: ClientSession, headers: dict[str, str] ) -> None: self._api_config = api_config self._client = client self._headers = headers - async def create_job(self, payload: Dict[str, Any]) -> Dict[str, Any]: + async def create_job(self, payload: dict[str, Any]) -> dict[str, Any]: url = self._api_config.jobs_base_url async with self._client.post(url, headers=self._headers, json=payload) as resp: 
assert resp.status == HTTPAccepted.status_code, await resp.text() @@ -171,7 +172,7 @@ async def create_job(self, payload: Dict[str, Any]) -> Dict[str, Any]: assert result["status"] == "pending" return result - async def get_all_jobs(self, params: Any = None) -> List[Dict[str, Any]]: + async def get_all_jobs(self, params: Any = None) -> list[dict[str, Any]]: url = self._api_config.jobs_base_url headers = self._headers.copy() headers["Accept"] = "application/x-ndjson" @@ -189,8 +190,8 @@ async def get_all_jobs(self, params: Any = None) -> List[Dict[str, Any]]: async def get_job_by_id( self, job_id: str, - headers: Optional[Dict[str, str]] = None, - ) -> Dict[str, Any]: + headers: Optional[dict[str, str]] = None, + ) -> dict[str, Any]: url = self._api_config.generate_job_url(job_id) async with self._client.get(url, headers=headers or self._headers) as response: response_text = await response.text() @@ -201,7 +202,7 @@ async def get_job_by_id( async def get_job_materialized_by_id( self, job_id: str, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> bool: url = ( self._api_config.generate_job_url(job_id) @@ -219,11 +220,11 @@ async def long_polling_by_job_id( interval_s: float = 0.5, max_time: float = 300, unreachable_optimization: bool = True, - headers: Optional[Dict[str, str]] = None, - ) -> Dict[str, Any]: + headers: Optional[dict[str, str]] = None, + ) -> dict[str, Any]: # A little optimization with unreachable statuses - unreachable_statuses_map: Dict[str, List[str]] = { + unreachable_statuses_map: dict[str, list[str]] = { JobStatus.PENDING.value: [ JobStatus.RUNNING.value, JobStatus.SUCCEEDED.value, @@ -246,7 +247,7 @@ async def long_polling_by_job_id( JobStatus.FAILED.value, ], } - stop_statuses: List[str] = [] + stop_statuses: list[str] = [] if unreachable_optimization and status in unreachable_statuses_map: stop_statuses = unreachable_statuses_map[status] @@ -265,7 +266,7 @@ async def long_polling_by_job_id( async def wait_job_creation( self, job_id: str, interval_s: float = 0.5, max_time: float = 300 - ) -> Dict[str, Any]: + ) -> dict[str, Any]: t0 = time.monotonic() while True: response = await self.get_job_by_id(job_id) @@ -298,7 +299,7 @@ async def delete_job( self, job_id: str, assert_success: bool = True, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> None: url = self._api_config.generate_job_url(job_id) async with self._client.delete( @@ -313,7 +314,7 @@ async def drop_job( self, job_id: str, assert_success: bool = True, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> None: url = self._api_config.generate_job_url(job_id) + "/drop" async with self._client.post(url, headers=headers or self._headers) as response: @@ -327,7 +328,7 @@ async def drop_progress( job_id: str, logs_removed: Optional[bool] = None, assert_success: bool = True, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> None: url = self._api_config.generate_job_url(job_id) + "/drop_progress" payload = {} @@ -346,7 +347,7 @@ async def drop_progress( async def jobs_client_factory( api: ApiConfig, client: ClientSession ) -> AsyncIterator[Callable[[_User], JobsClient]]: - jobs_clients: List[JobsClient] = [] + jobs_clients: list[JobsClient] = [] def impl(user: _User) -> JobsClient: jobs_client = JobsClient(api, client, headers=user.headers) @@ -355,7 +356,7 @@ def impl(user: _User) -> JobsClient: yield impl - deleted: Set[str] = set() + deleted: 
set[str] = set() for jobs_client in jobs_clients: try: jobs = await jobs_client.get_all_jobs() @@ -406,8 +407,8 @@ async def infinite_job( @pytest.fixture -def job_request_factory() -> Callable[[], Dict[str, Any]]: - def _factory(cluster_name: Optional[str] = None) -> Dict[str, Any]: +def job_request_factory() -> Callable[[], dict[str, Any]]: + def _factory(cluster_name: Optional[str] = None) -> dict[str, Any]: # Note: Optional fields (as "name") should not have a value here request = { "container": { @@ -427,6 +428,6 @@ def _factory(cluster_name: Optional[str] = None) -> Dict[str, Any]: @pytest.fixture async def job_submit( - job_request_factory: Callable[[], Dict[str, Any]] -) -> Dict[str, Any]: + job_request_factory: Callable[[], dict[str, Any]] +) -> dict[str, Any]: return job_request_factory() diff --git a/tests/integration/auth.py b/tests/integration/auth.py index 0e24e58f7..e49f348d0 100644 --- a/tests/integration/auth.py +++ b/tests/integration/auth.py @@ -1,18 +1,8 @@ import asyncio +from collections.abc import AsyncGenerator, AsyncIterator, Callable from contextlib import asynccontextmanager from dataclasses import dataclass, field -from typing import ( - AsyncGenerator, - AsyncIterator, - Callable, - Dict, - List, - Optional, - Protocol, - Tuple, - Union, - cast, -) +from typing import Optional, Protocol, Union, cast import aiodocker import pytest @@ -151,7 +141,7 @@ async def wait_for_auth_server( class _User: name: str token: str - clusters: List[str] = field(default_factory=list) + clusters: list[str] = field(default_factory=list) @property def cluster_name(self) -> str: @@ -159,7 +149,7 @@ def cluster_name(self) -> str: return self.clusters[0] @property - def headers(self) -> Dict[str, str]: + def headers(self) -> dict[str, str]: return {AUTHORIZATION: f"Bearer {self.token}"} @@ -173,7 +163,7 @@ async def __call__( self, name: Optional[str] = None, clusters: Optional[ - List[Union[Tuple[str, Balance, Quota], Tuple[str, str, Balance, Quota]]] + list[Union[tuple[str, Balance, Quota], tuple[str, str, Balance, Quota]]] ] = None, ) -> _User: ... 
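
The other substitution recurring in these fixtures is `typing.AsyncContextManager` giving way to `contextlib.AbstractAsyncContextManager`, which PEP 585 also makes subscriptable on 3.9. A minimal sketch of how a client factory is typed under that convention (the names below are illustrative, not from this repository):

    from collections.abc import AsyncIterator
    from contextlib import AbstractAsyncContextManager, asynccontextmanager

    @asynccontextmanager
    async def open_client(token: str) -> AsyncIterator[str]:
        # Hypothetical resource; the yielded value is what `async with ... as`
        # binds, so the manager's type parameter is str.
        yield f"client-for-{token}"

    def client_factory(token: str) -> AbstractAsyncContextManager[str]:
        # The decorated generator already returns an async context manager,
        # so no typing.AsyncContextManager import is required.
        return open_client(token)
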
@@ -190,7 +180,7 @@ async def regular_user_factory( async def _factory( name: Optional[str] = None, clusters: Optional[ - List[Union[Tuple[str, Balance, Quota], Tuple[str, str, Balance, Quota]]] + list[Union[tuple[str, Balance, Quota], tuple[str, str, Balance, Quota]]] ] = None, ) -> _User: if not name: @@ -201,10 +191,10 @@ async def _factory( for entry in clusters: org_name: Optional[str] = None if len(entry) == 3: - cluster, balance, quota = cast(Tuple[str, Balance, Quota], entry) + cluster, balance, quota = cast(tuple[str, Balance, Quota], entry) else: cluster, org_name, balance, quota = cast( - Tuple[str, str, Balance, Quota], entry + tuple[str, str, Balance, Quota], entry ) try: await admin_client.create_cluster(cluster) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 3ae6086a4..62caf1383 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,22 +1,13 @@ import asyncio import json import uuid +from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Mapping from contextlib import asynccontextmanager from dataclasses import dataclass from datetime import datetime, timezone from decimal import Decimal from pathlib import Path, PurePath -from typing import ( - Any, - AsyncIterator, - Awaitable, - Callable, - Dict, - Iterator, - List, - Mapping, - Optional, -) +from typing import Any, Optional from urllib.parse import urlsplit import aiohttp @@ -78,7 +69,7 @@ def event_loop() -> Iterator[asyncio.AbstractEventLoop]: @pytest.fixture(scope="session") -async def kube_config_payload() -> Dict[str, Any]: +async def kube_config_payload() -> dict[str, Any]: process = await asyncio.create_subprocess_exec( "kubectl", "config", "view", "-o", "json", stdout=asyncio.subprocess.PIPE ) @@ -88,7 +79,7 @@ async def kube_config_payload() -> Dict[str, Any]: @pytest.fixture(scope="session") -async def kube_config_cluster_payload(kube_config_payload: Dict[str, Any]) -> Any: +async def kube_config_cluster_payload(kube_config_payload: dict[str, Any]) -> Any: cluster_name = "minikube" clusters = { cluster["name"]: cluster["cluster"] @@ -99,7 +90,7 @@ async def kube_config_cluster_payload(kube_config_payload: Dict[str, Any]) -> An @pytest.fixture(scope="session") def cert_authority_data_pem( - kube_config_cluster_payload: Dict[str, Any] + kube_config_cluster_payload: dict[str, Any] ) -> Optional[str]: ca_path = kube_config_cluster_payload["certificate-authority"] if ca_path: @@ -108,7 +99,7 @@ def cert_authority_data_pem( @pytest.fixture(scope="session") -async def kube_config_user_payload(kube_config_payload: Dict[str, Any]) -> Any: +async def kube_config_user_payload(kube_config_payload: dict[str, Any]) -> Any: user_name = "minikube" users = {user["name"]: user["user"] for user in kube_config_payload["users"]} return users[user_name] @@ -246,8 +237,8 @@ async def orchestrator_config( @pytest.fixture(scope="session") def kube_config_factory( - kube_config_cluster_payload: Dict[str, Any], - kube_config_user_payload: Dict[str, Any], + kube_config_cluster_payload: dict[str, Any], + kube_config_user_payload: dict[str, Any], cert_authority_data_pem: Optional[str], ) -> Iterator[Callable[..., KubeConfig]]: cluster = kube_config_cluster_payload @@ -321,13 +312,13 @@ async def _create( @pytest.fixture(scope="session") -async def kube_ingress_ip(kube_config_cluster_payload: Dict[str, Any]) -> str: +async def kube_ingress_ip(kube_config_cluster_payload: dict[str, Any]) -> str: cluster = kube_config_cluster_payload return 
urlsplit(cluster["server"]).hostname class MyKubeClient(KubeClient): - _created_pvcs: List[str] + _created_pvcs: list[str] async def init(self) -> None: await super().init() @@ -373,7 +364,7 @@ async def delete_pvc( self._check_status_payload(payload) async def update_or_create_secret( - self, secret_name: str, namespace: str, data: Optional[Dict[str, str]] = None + self, secret_name: str, namespace: str, data: Optional[dict[str, str]] = None ) -> None: url = self._generate_all_secrets_url(namespace) data = data or {} @@ -394,7 +385,7 @@ async def wait_pod_scheduled( timeout_s: float = 5.0, interval_s: float = 1.0, ) -> None: - raw_pod: Optional[Dict[str, Any]] = None + raw_pod: Optional[dict[str, Any]] = None try: async with timeout(timeout_s): while True: @@ -675,7 +666,7 @@ def kube_node() -> str: @pytest.fixture -def default_node_capacity() -> Dict[str, Any]: +def default_node_capacity() -> dict[str, Any]: return {"pods": "110", "memory": "1Gi", "cpu": 2, "nvidia.com/gpu": 1} @@ -684,7 +675,7 @@ async def kube_node_gpu( kube_config: KubeConfig, kube_client: MyKubeClient, delete_node_later: Callable[[str], Awaitable[None]], - default_node_capacity: Dict[str, Any], + default_node_capacity: dict[str, Any], ) -> AsyncIterator[str]: node_name = str(uuid.uuid4()) await delete_node_later(node_name) @@ -736,7 +727,7 @@ async def kube_node_preemptible( kube_config_node_preemptible: KubeConfig, kube_client: MyKubeClient, delete_node_later: Callable[[str], Awaitable[None]], - default_node_capacity: Dict[str, Any], + default_node_capacity: dict[str, Any], ) -> AsyncIterator[str]: node_name = str(uuid.uuid4()) await delete_node_later(node_name) diff --git a/tests/integration/diskapi.py b/tests/integration/diskapi.py index 18037a6a0..b40fb809e 100644 --- a/tests/integration/diskapi.py +++ b/tests/integration/diskapi.py @@ -1,8 +1,9 @@ import asyncio import subprocess import sys -from contextlib import asynccontextmanager -from typing import Any, AsyncContextManager, AsyncIterator, Callable, Dict, Optional +from collections.abc import AsyncIterator, Callable +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import Any, Optional import aiodocker import aiodocker.containers @@ -141,7 +142,7 @@ class DiskAPIClient: def __init__(self, cluster_name: str, url: URL, auth_token: str): self._cluster_name = cluster_name self._base_url = url / "api/v1" - headers: Dict[str, str] = {} + headers: dict[str, str] = {} if auth_token: headers["Authorization"] = f"Bearer {auth_token}" self._client = aiohttp.ClientSession(headers=headers) @@ -188,8 +189,8 @@ async def create_disk_api_client( @pytest.fixture async def disk_client_factory( disk_server_url: URL, -) -> Callable[[_User], AsyncContextManager[DiskAPIClient]]: - def _f(user: _User) -> AsyncContextManager[DiskAPIClient]: +) -> Callable[[_User], AbstractAsyncContextManager[DiskAPIClient]]: + def _f(user: _User) -> AbstractAsyncContextManager[DiskAPIClient]: return create_disk_api_client(user.cluster_name, disk_server_url, user.token) return _f diff --git a/tests/integration/docker.py b/tests/integration/docker.py index 6b0c52b68..aebdd014f 100644 --- a/tests/integration/docker.py +++ b/tests/integration/docker.py @@ -1,4 +1,5 @@ -from typing import Any, AsyncIterator +from collections.abc import AsyncIterator +from typing import Any import aiodocker import pytest diff --git a/tests/integration/notifications.py b/tests/integration/notifications.py index 9c79c054e..74c7a2d9a 100644 --- a/tests/integration/notifications.py +++ 
b/tests/integration/notifications.py @@ -1,4 +1,5 @@ -from typing import Any, AsyncIterator, NamedTuple, Tuple +from collections.abc import AsyncIterator +from typing import Any, NamedTuple import aiohttp.web import pytest @@ -19,7 +20,7 @@ def url(self) -> URL: return URL(f"http://{self.address.host}:{self.address.port}") @property - def requests(self) -> Tuple[Tuple[str, Any]]: + def requests(self) -> tuple[tuple[str, Any]]: return tuple(request for request in self.app["requests"]) # type: ignore diff --git a/tests/integration/postgres.py b/tests/integration/postgres.py index d9311abf6..49d473f51 100644 --- a/tests/integration/postgres.py +++ b/tests/integration/postgres.py @@ -1,5 +1,5 @@ import time -from typing import AsyncIterator +from collections.abc import AsyncIterator import aiodocker import aiodocker.containers diff --git a/tests/integration/secrets.py b/tests/integration/secrets.py index ff68e5b4b..a44cb087e 100644 --- a/tests/integration/secrets.py +++ b/tests/integration/secrets.py @@ -2,8 +2,9 @@ import base64 import subprocess import sys -from contextlib import asynccontextmanager -from typing import Any, AsyncContextManager, AsyncIterator, Callable, Dict, Optional +from collections.abc import AsyncIterator, Callable +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import Any, Optional import aiodocker import aiodocker.containers @@ -138,7 +139,7 @@ class SecretsClient: def __init__(self, url: URL, user_name: str, user_token: str): self._base_url = url / "api/v1" self._user_name = user_name - headers: Dict[str, str] = {} + headers: dict[str, str] = {} if user_token: headers["Authorization"] = f"Bearer {user_token}" self._client = aiohttp.ClientSession(headers=headers) @@ -187,8 +188,8 @@ async def create_secrets_client( @pytest.fixture async def secrets_client_factory( secrets_server_url: URL, -) -> Callable[[_User], AsyncContextManager[SecretsClient]]: - def _f(user: _User) -> AsyncContextManager[SecretsClient]: +) -> Callable[[_User], AbstractAsyncContextManager[SecretsClient]]: + def _f(user: _User) -> AbstractAsyncContextManager[SecretsClient]: return create_secrets_client(secrets_server_url, user.name, user.token) return _f diff --git a/tests/integration/test_api.py b/tests/integration/test_api.py index 1282e77ab..d1e634076 100644 --- a/tests/integration/test_api.py +++ b/tests/integration/test_api.py @@ -2,19 +2,10 @@ import json import re from collections import defaultdict +from collections.abc import AsyncIterator, Awaitable, Callable, Iterator +from contextlib import AbstractAsyncContextManager from decimal import Decimal -from typing import ( - Any, - AsyncContextManager, - AsyncIterator, - Awaitable, - Callable, - Dict, - Iterator, - List, - Optional, - Tuple, -) +from typing import Any, Optional from unittest import mock import aiohttp.web @@ -53,7 +44,7 @@ def cluster_name() -> str: @pytest.fixture -def cluster_configs_payload() -> List[Dict[str, Any]]: +def cluster_configs_payload() -> list[dict[str, Any]]: return [ { "name": "cluster_name", @@ -180,14 +171,14 @@ async def test_clusters_sync( self, api: ApiConfig, client: aiohttp.ClientSession, - cluster_configs_payload: List[Dict[str, Any]], + cluster_configs_payload: list[dict[str, Any]], cluster_user: _User, ) -> None: cluster_registry: ClusterConfigRegistry = api.runner._app["config_app"][ "jobs_service" ]._cluster_registry - async def assert_cluster_names(names: List[str]) -> None: + async def assert_cluster_names(names: list[str]) -> None: async def _loop() -> None: 
while names != cluster_registry.cluster_names: await asyncio.sleep(0.1) @@ -339,7 +330,7 @@ async def test_config( }, ], } - expected_payload: Dict[str, Any] = { + expected_payload: dict[str, Any] = { "admin_url": f"{admin_url}", "clusters": [ expected_cluster_payload, @@ -460,7 +451,7 @@ async def test_config_with_orgs( }, ], } - expected_payload: Dict[str, Any] = { + expected_payload: dict[str, Any] = { "admin_url": f"{admin_url}", "clusters": [ expected_cluster_payload, @@ -573,7 +564,7 @@ async def test_config_with_oauth( }, ], } - expected_payload: Dict[str, Any] = { + expected_payload: dict[str, Any] = { "auth_url": "https://platform-auth0-url/auth", "token_url": "https://platform-auth0-url/token", "logout_url": "https://platform-auth0-url/logout", @@ -599,7 +590,7 @@ async def test_create_job_with_http( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -627,7 +618,7 @@ async def test_create_job_without_http( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -649,7 +640,7 @@ async def test_create_job_owner_with_slash( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], regular_user: _User, service_account_factory: ServiceAccountFactory, jobs_client_factory: Callable[[_User], JobsClient], @@ -683,7 +674,7 @@ async def test_create_job_with_org( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], service_account_factory: ServiceAccountFactory, regular_user_factory: UserFactory, jobs_client_factory: Callable[[_User], JobsClient], @@ -714,7 +705,7 @@ async def test_create_job_with_pass_config( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -738,7 +729,7 @@ async def test_create_job_with_wait_for_jobs_quota( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -760,7 +751,7 @@ async def test_create_job_with_privileged_flag( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -784,7 +775,7 @@ async def test_create_job_with_tty( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -813,11 +804,11 @@ async def _run_job_with_secrets( self, api: ApiConfig, client: aiohttp.ClientSession, jobs_client: JobsClient ) -> Callable[..., Awaitable[None]]: async def _run( - job_submit: Dict[str, Any], + job_submit: dict[str, Any], user: _User, *, - secret_env: Optional[Dict[str, str]] = None, - secret_volumes: Optional[Dict[str, str]] = None, + secret_env: Optional[dict[str, str]] = None, + secret_volumes: Optional[dict[str, str]] = None, ) -> None: job_id = "" try: @@ -854,7 +845,7 @@ async def test_create_job_with_secret_env_single_ok( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, regular_secrets_client: SecretsClient, @@ -879,7 +870,7 @@ async def 
test_create_job_with_secret_volume_single_ok( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, regular_secrets_client: SecretsClient, @@ -924,10 +915,12 @@ async def test_create_job_with_org_secret_volume_single_ok( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client_factory: Callable[[_User], JobsClient], regular_user_factory: UserFactory, - secrets_client_factory: Callable[[_User], AsyncContextManager[SecretsClient]], + secrets_client_factory: Callable[ + [_User], AbstractAsyncContextManager[SecretsClient] + ], ) -> None: org_user = await regular_user_factory( clusters=[ @@ -978,10 +971,12 @@ async def test_create_job_with_org_secret_env_single_ok( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client_factory: Callable[[_User], JobsClient], regular_user_factory: UserFactory, - secrets_client_factory: Callable[[_User], AsyncContextManager[SecretsClient]], + secrets_client_factory: Callable[ + [_User], AbstractAsyncContextManager[SecretsClient] + ], _run_job_with_secrets: Callable[..., Awaitable[None]], ) -> None: org_user = await regular_user_factory( @@ -1011,7 +1006,7 @@ async def test_create_job_with_secret_volume_user_with_slash_single_ok( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], regular_user: _User, service_account_factory: ServiceAccountFactory, jobs_client_factory: Callable[[_User], JobsClient], @@ -1066,7 +1061,7 @@ async def test_create_job_with_secret_env_user_with_slash_single_ok( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], regular_user: _User, service_account_factory: ServiceAccountFactory, jobs_client_factory: Callable[[_User], JobsClient], @@ -1098,7 +1093,7 @@ async def test_create_job_with_disk_volume_single_ok( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, regular_disk_api_client: DiskAPIClient, @@ -1148,10 +1143,12 @@ async def test_create_job_with_org_disk_volume_single_ok( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client_factory: Callable[[_User], JobsClient], regular_user_factory: UserFactory, - disk_client_factory: Callable[[_User], AsyncContextManager[DiskAPIClient]], + disk_client_factory: Callable[ + [_User], AbstractAsyncContextManager[DiskAPIClient] + ], ) -> None: org_user = await regular_user_factory( clusters=[ @@ -1206,7 +1203,7 @@ async def test_create_job_with_disk_volume_user_with_slash_single_ok( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], regular_user: _User, service_account_factory: ServiceAccountFactory, jobs_client_factory: Callable[[_User], JobsClient], @@ -1262,7 +1259,7 @@ async def test_create_job_with_one_disk_volume_multiple_mounts_ok( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, regular_disk_api_client: DiskAPIClient, @@ -1318,7 +1315,7 @@ async def test_create_job_with_multiple_disk_volumes_ok( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: 
dict[str, Any], jobs_client: JobsClient, regular_user: _User, regular_disk_api_client: DiskAPIClient, @@ -1378,7 +1375,7 @@ async def test_disk_volume_data_persisted_between_jobs( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, regular_disk_api_client: DiskAPIClient, @@ -1439,7 +1436,7 @@ async def test_disk_volume_race_between_jobs_ok( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, regular_disk_api_client: DiskAPIClient, @@ -1509,7 +1506,7 @@ async def test_create_job_disk_volumes_unexisting_fail( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -1541,10 +1538,10 @@ async def test_create_job_with_other_user_disk_fail( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], test_cluster_name: str, regular_user_factory: UserFactory, - disk_client_factory: Callable[..., AsyncContextManager[DiskAPIClient]], + disk_client_factory: Callable[..., AbstractAsyncContextManager[DiskAPIClient]], read_only: bool, ) -> None: cluster = test_cluster_name @@ -1578,7 +1575,7 @@ async def test_create_job_with_disk_volume_wrong_scheme_fail( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -1605,7 +1602,7 @@ async def test_create_job_with_disk_volume_wrong_cluster_fail( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -1635,7 +1632,7 @@ async def test_create_job_with_disk_volume_invalid_mount_with_dots_fail( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -1658,7 +1655,7 @@ async def test_create_job_with_disk_volume_invalid_mount_relative_fail( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -1681,7 +1678,7 @@ async def test_create_job_disk_volumes_same_mount_points_fail( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -1708,7 +1705,7 @@ async def test_create_job_with_secret_volumes_different_dirs_same_filenames_ok( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, regular_secrets_client: SecretsClient, @@ -1749,7 +1746,7 @@ async def test_create_job_with_secret_env_and_secret_volumes_ok( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, regular_secrets_client: SecretsClient, @@ -1820,7 +1817,7 @@ async def test_create_job_with_secret_same_secret_in_env_and_volumes_ok( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, regular_secrets_client: SecretsClient, @@ -1878,7 +1875,7 @@ 
async def test_create_job_with_secret_same_secret_env_ok( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, regular_secrets_client: SecretsClient, @@ -1924,7 +1921,7 @@ async def test_create_job_with_secret_same_secret_volumes_different_dirs_ok( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, regular_secrets_client: SecretsClient, @@ -1970,7 +1967,7 @@ async def test_create_job_with_secret_same_secret_volumes_different_filenames_ok self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, regular_secrets_client: SecretsClient, @@ -2016,7 +2013,7 @@ async def test_create_job_with_secret_volumes_relative_directory_ok( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, regular_secrets_client: SecretsClient, @@ -2058,7 +2055,7 @@ async def test_create_job_with_secret_missing_all_user_secrets_fail( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, secret_kind: str, @@ -2101,7 +2098,7 @@ async def test_create_job_with_secret_missing_all_requested_secrets_fail( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, regular_secrets_client: SecretsClient, @@ -2152,7 +2149,7 @@ async def test_create_job_with_secret_env_missing_some_requested_secrets_fail( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, regular_secrets_client: SecretsClient, @@ -2198,10 +2195,12 @@ async def test_create_job_with_secret_env_use_other_user_secret_fail( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], test_cluster_name: str, regular_user_factory: UserFactory, - secrets_client_factory: Callable[..., AsyncContextManager[SecretsClient]], + secrets_client_factory: Callable[ + ..., AbstractAsyncContextManager[SecretsClient] + ], secret_kind: str, ) -> None: cluster = test_cluster_name @@ -2244,11 +2243,13 @@ async def test_create_job_with_secret_env_use_other_user_secret_success( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, test_cluster_name: str, regular_user_factory: UserFactory, - secrets_client_factory: Callable[..., AsyncContextManager[SecretsClient]], + secrets_client_factory: Callable[ + ..., AbstractAsyncContextManager[SecretsClient] + ], secret_kind: str, share_secret: Callable[..., Awaitable[None]], _run_job_with_secrets: Callable[..., Awaitable[None]], @@ -2326,7 +2327,7 @@ async def test_create_job_with_secret_env_wrong_scheme_fail( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, secret_kind: str, @@ -2365,7 +2366,7 @@ async def test_create_job_with_secret_env_wrong_cluster_fail( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: 
JobsClient, regular_user: _User, secret_kind: str, @@ -2403,7 +2404,7 @@ async def test_create_job_with_secret_volume_invalid_mount_with_dots_fail( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -2426,7 +2427,7 @@ async def test_create_job_with_secret_volume_invalid_mount_relative_fail( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -2449,7 +2450,7 @@ async def test_create_job_with_and_secret_volumes_same_mount_points_fail( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -2476,7 +2477,7 @@ async def test_create_job_set_max_run_time( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -2494,7 +2495,7 @@ async def test_get_job_run_time_seconds( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -2523,7 +2524,7 @@ async def test_create_job_volume_wrong_storage_scheme( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -2548,7 +2549,7 @@ async def test_create_job_volume_wrong_cluster_name( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -2574,7 +2575,7 @@ async def test_create_job_volume_wrong_path_with_dots( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -2599,7 +2600,7 @@ async def test_create_job_volume_wrong_path_not_absolute( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -2765,7 +2766,7 @@ async def test_allowed_image( @pytest.mark.asyncio async def test_create_job_unauthorized_no_token( - self, api: ApiConfig, client: aiohttp.ClientSession, job_submit: Dict[str, Any] + self, api: ApiConfig, client: aiohttp.ClientSession, job_submit: dict[str, Any] ) -> None: url = api.jobs_base_url async with client.post(url, json=job_submit) as response: @@ -2773,7 +2774,7 @@ async def test_create_job_unauthorized_no_token( @pytest.mark.asyncio async def test_create_job_unauthorized_invalid_token( - self, api: ApiConfig, client: aiohttp.ClientSession, job_submit: Dict[str, Any] + self, api: ApiConfig, client: aiohttp.ClientSession, job_submit: dict[str, Any] ) -> None: url = api.jobs_base_url headers = {"Authorization": "Bearer INVALID"} @@ -2785,7 +2786,7 @@ async def test_create_job_invalid_job_name( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -2804,7 +2805,7 @@ async def test_create_job_user_has_unknown_cluster_name( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, 
regular_user_with_missing_cluster_name: _User, ) -> None: @@ -2823,7 +2824,7 @@ async def test_create_job_unknown_cluster_name( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -2847,7 +2848,7 @@ async def test_create_job_no_clusters( api: ApiConfig, auth_api: AuthApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, admin_token: str, regular_user_with_missing_cluster_name: _User, @@ -2865,7 +2866,7 @@ async def test_create_job( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -2915,7 +2916,7 @@ async def test_create_job_from_preset( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -2950,7 +2951,7 @@ async def test_create_job_without_name_http_url_named_not_sent( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -2971,7 +2972,7 @@ async def test_create_multiple_jobs_with_same_name_fail( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], regular_user: _User, jobs_client: JobsClient, ) -> None: @@ -3006,7 +3007,7 @@ async def test_create_job_with_tags( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], regular_user: _User, jobs_client: JobsClient, ) -> None: @@ -3035,7 +3036,7 @@ async def test_create_job_has_credits( self, api: ApiConfig, client: aiohttp.ClientSession, - job_request_factory: Callable[[], Dict[str, Any]], + job_request_factory: Callable[[], dict[str, Any]], regular_user_factory: UserFactory, jobs_client_factory: Callable[[_User], JobsClient], test_cluster_name: str, @@ -3062,7 +3063,7 @@ async def test_create_job_no_credits( self, api: ApiConfig, client: aiohttp.ClientSession, - job_request_factory: Callable[[], Dict[str, Any]], + job_request_factory: Callable[[], dict[str, Any]], regular_user_factory: UserFactory, credits: Decimal, cluster_name: str, @@ -3082,7 +3083,7 @@ async def test_create_multiple_jobs_with_same_name_after_first_finished( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], regular_user: _User, jobs_client: JobsClient, ) -> None: @@ -3148,7 +3149,7 @@ async def test_get_all_jobs_not_streamed( client: aiohttp.ClientSession, jobs_client: JobsClient, regular_user: _User, - job_request_factory: Callable[[], Dict[str, Any]], + job_request_factory: Callable[[], dict[str, Any]], ) -> None: url = api.jobs_base_url headers = regular_user.headers @@ -3184,7 +3185,7 @@ async def test_get_all_jobs_filter_wrong_status( client: aiohttp.ClientSession, jobs_client: JobsClient, regular_user: _User, - job_request_factory: Callable[[], Dict[str, Any]], + job_request_factory: Callable[[], dict[str, Any]], ) -> None: headers = regular_user.headers url = api.jobs_base_url @@ -3204,7 +3205,7 @@ async def test_get_all_jobs_filter_by_status_only_single_status_pending( client: aiohttp.ClientSession, jobs_client: JobsClient, regular_user: _User, - job_request_factory: Callable[[], Dict[str, Any]], + job_request_factory: Callable[[], 
dict[str, Any]], ) -> None: url = api.jobs_base_url headers = regular_user.headers @@ -3234,7 +3235,7 @@ async def test_get_all_jobs_filter_by_tags( client: aiohttp.ClientSession, jobs_client: JobsClient, regular_user: _User, - job_request_factory: Callable[[], Dict[str, Any]], + job_request_factory: Callable[[], dict[str, Any]], ) -> None: url = api.jobs_base_url headers = regular_user.headers @@ -3295,7 +3296,7 @@ async def test_get_all_jobs_filter_by_status_only( client: aiohttp.ClientSession, jobs_client: JobsClient, regular_user: _User, - job_request_factory: Callable[[], Dict[str, Any]], + job_request_factory: Callable[[], dict[str, Any]], ) -> None: url = api.jobs_base_url headers = regular_user.headers @@ -3359,7 +3360,7 @@ async def test_get_all_jobs_filter_by_date_range( client: aiohttp.ClientSession, jobs_client: JobsClient, regular_user: _User, - job_request_factory: Callable[[], Dict[str, Any]], + job_request_factory: Callable[[], dict[str, Any]], ) -> None: url = api.jobs_base_url headers = regular_user.headers @@ -3408,7 +3409,7 @@ async def test_get_all_jobs_filter_by_org( api: ApiConfig, client: aiohttp.ClientSession, jobs_client_factory: Callable[[_User], JobsClient], - job_request_factory: Callable[[], Dict[str, Any]], + job_request_factory: Callable[[], dict[str, Any]], regular_user_factory: UserFactory, ) -> None: url = api.jobs_base_url @@ -3462,12 +3463,12 @@ async def run_job( api: ApiConfig, client: aiohttp.ClientSession, jobs_client_factory: Callable[[_User], JobsClient], - ) -> AsyncIterator[Callable[[_User, Dict[str, Any], bool, bool], Awaitable[str]]]: + ) -> AsyncIterator[Callable[[_User, dict[str, Any], bool, bool], Awaitable[str]]]: cleanup_pairs = [] async def _impl( user: _User, - job_request: Dict[str, Any], + job_request: dict[str, Any], do_kill: bool = False, do_wait: bool = True, ) -> str: @@ -3529,9 +3530,9 @@ async def _impl( @pytest.fixture def create_job_request_with_name( - self, job_request_factory: Callable[[], Dict[str, Any]] - ) -> Iterator[Callable[[str], Dict[str, Any]]]: - def _impl(job_name: str) -> Dict[str, Any]: + self, job_request_factory: Callable[[], dict[str, Any]] + ) -> Iterator[Callable[[str], dict[str, Any]]]: + def _impl(job_name: str) -> dict[str, Any]: job_request = job_request_factory() job_request["container"]["command"] = "sleep 30m" job_request["name"] = job_name @@ -3541,9 +3542,9 @@ def _impl(job_name: str) -> Dict[str, Any]: @pytest.fixture def create_job_request_no_name( - self, job_request_factory: Callable[[], Dict[str, Any]] - ) -> Iterator[Callable[[], Dict[str, Any]]]: - def _impl() -> Dict[str, Any]: + self, job_request_factory: Callable[[], dict[str, Any]] + ) -> Iterator[Callable[[], dict[str, Any]]]: + def _impl() -> dict[str, Any]: job_request = job_request_factory() job_request["container"]["command"] = "sleep 30m" return job_request @@ -3557,10 +3558,10 @@ async def test_get_all_jobs_filter_by_job_name_and_statuses( client: aiohttp.ClientSession, regular_user_factory: UserFactory, jobs_client_factory: Callable[[_User], JobsClient], - job_request_factory: Callable[[], Dict[str, Any]], + job_request_factory: Callable[[], dict[str, Any]], run_job: Callable[..., Awaitable[str]], - create_job_request_no_name: Callable[[], Dict[str, Any]], - create_job_request_with_name: Callable[[str], Dict[str, Any]], + create_job_request_no_name: Callable[[], dict[str, Any]], + create_job_request_with_name: Callable[[str], dict[str, Any]], ) -> None: job_name = "test-job-name" job_req_no_name = create_job_request_no_name() @@ 
-3632,11 +3633,11 @@ async def test_get_all_jobs_filter_by_job_name_self_owner_and_statuses( client: aiohttp.ClientSession, regular_user_factory: UserFactory, jobs_client_factory: Callable[[_User], JobsClient], - job_request_factory: Callable[[], Dict[str, Any]], + job_request_factory: Callable[[], dict[str, Any]], run_job: Callable[..., Awaitable[str]], share_job: Callable[[_User, _User, Any], Awaitable[None]], - create_job_request_no_name: Callable[[], Dict[str, Any]], - create_job_request_with_name: Callable[[str], Dict[str, Any]], + create_job_request_no_name: Callable[[], dict[str, Any]], + create_job_request_with_name: Callable[[str], dict[str, Any]], ) -> None: job_name = "test-job-name" job_req_no_name = create_job_request_no_name() @@ -3715,11 +3716,11 @@ async def test_get_all_jobs_filter_by_job_name_another_owner_and_statuses( client: aiohttp.ClientSession, regular_user_factory: UserFactory, jobs_client_factory: Callable[[_User], JobsClient], - job_request_factory: Callable[[], Dict[str, Any]], + job_request_factory: Callable[[], dict[str, Any]], run_job: Callable[..., Awaitable[str]], share_job: Callable[[_User, _User, Any], Awaitable[None]], - create_job_request_no_name: Callable[[], Dict[str, Any]], - create_job_request_with_name: Callable[[str], Dict[str, Any]], + create_job_request_no_name: Callable[[], dict[str, Any]], + create_job_request_with_name: Callable[[str], dict[str, Any]], ) -> None: job_name = "test-job-name" job_req_no_name = create_job_request_no_name() @@ -3797,11 +3798,11 @@ async def test_get_all_jobs_filter_by_job_name_multiple_owners_and_statuses( client: aiohttp.ClientSession, regular_user_factory: UserFactory, jobs_client_factory: Callable[[_User], JobsClient], - job_request_factory: Callable[[], Dict[str, Any]], + job_request_factory: Callable[[], dict[str, Any]], run_job: Callable[..., Awaitable[str]], share_job: Callable[[_User, _User, Any], Awaitable[None]], - create_job_request_no_name: Callable[[], Dict[str, Any]], - create_job_request_with_name: Callable[[str], Dict[str, Any]], + create_job_request_no_name: Callable[[], dict[str, Any]], + create_job_request_with_name: Callable[[str], dict[str, Any]], ) -> None: job_name = "test-job-name" job_req_no_name = create_job_request_no_name() @@ -3979,7 +3980,7 @@ async def test_get_all_jobs_shared( jobs_client_factory: Callable[[_User], JobsClient], api: ApiConfig, client: aiohttp.ClientSession, - job_request_factory: Callable[[], Dict[str, Any]], + job_request_factory: Callable[[], dict[str, Any]], regular_user_factory: UserFactory, auth_client: AuthClient, cluster_name: str, @@ -4033,7 +4034,7 @@ async def test_get_shared_job( jobs_client_factory: Callable[[_User], JobsClient], api: ApiConfig, client: aiohttp.ClientSession, - job_request_factory: Callable[[], Dict[str, Any]], + job_request_factory: Callable[[], dict[str, Any]], regular_user_factory: UserFactory, auth_client: AuthClient, cluster_name: str, @@ -4083,7 +4084,7 @@ async def test_get_jobs_return_corrects_id( jobs_client: JobsClient, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], regular_user: _User, ) -> None: jobs_ids = [] @@ -4132,9 +4133,9 @@ async def test_get_jobs_by_name_preserves_chronological_order_without_statuses( jobs_client: JobsClient, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], regular_user: _User, - filters: Dict[str, Any], + filters: dict[str, Any], ) -> None: # unique job name generated per test-run is 
stored in "filters" job_submit["name"] = filters.get("name") @@ -4175,7 +4176,7 @@ async def test_get_job_by_cluster_name_and_statuses( jobs_client_factory: Callable[[_User], JobsClient], run_job: Callable[..., Awaitable[str]], share_job: Callable[[_User, _User, Any], Awaitable[None]], - create_job_request_no_name: Callable[[], Dict[str, Any]], + create_job_request_no_name: Callable[[], dict[str, Any]], ) -> None: job_req_no_name = create_job_request_no_name() usr1 = await regular_user_factory() @@ -4234,7 +4235,7 @@ async def test_get_job_by_hostname_self_owner( regular_user_factory: UserFactory, jobs_client_factory: Callable[[_User], JobsClient], run_job: Callable[..., Awaitable[str]], - create_job_request_with_name: Callable[[str], Dict[str, Any]], + create_job_request_with_name: Callable[[str], dict[str, Any]], ) -> None: job_name = "test-job-name" job_name2 = "test-job-name2" @@ -4287,7 +4288,7 @@ async def test_get_job_by_hostname_another_owner( jobs_client_factory: Callable[[_User], JobsClient], run_job: Callable[..., Awaitable[str]], share_job: Callable[[_User, _User, Any], Awaitable[None]], - create_job_request_with_name: Callable[[str], Dict[str, Any]], + create_job_request_with_name: Callable[[str], dict[str, Any]], ) -> None: job_name = "test-job-name" job_name2 = "test-job-name2" @@ -4325,7 +4326,7 @@ async def test_get_job_by_hostname_and_status( regular_user_factory: UserFactory, jobs_client_factory: Callable[[_User], JobsClient], run_job: Callable[..., Awaitable[str]], - create_job_request_with_name: Callable[[str], Dict[str, Any]], + create_job_request_with_name: Callable[[str], dict[str, Any]], ) -> None: job_name = "test-job-name" job_name2 = "test-job-name2" @@ -4365,7 +4366,7 @@ async def test_get_job_by_hostname_invalid_request( client: aiohttp.ClientSession, regular_user_factory: UserFactory, run_job: Callable[..., Awaitable[str]], - create_job_request_with_name: Callable[[str], Dict[str, Any]], + create_job_request_with_name: Callable[[str], dict[str, Any]], ) -> None: url = api.jobs_base_url job_name = "test-job-name" @@ -4397,7 +4398,7 @@ async def test_set_job_status_no_reason( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, run_job: Callable[..., Awaitable[str]], regular_user: _User, @@ -4432,7 +4433,7 @@ async def test_set_job_status_with_details( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, run_job: Callable[..., Awaitable[str]], regular_user: _User, @@ -4472,7 +4473,7 @@ async def test_set_job_status_wrong_status( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], run_job: Callable[..., Awaitable[str]], regular_user: _User, compute_user: _User, @@ -4490,7 +4491,7 @@ async def test_set_job_status_bad_transition( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, run_job: Callable[..., Awaitable[str]], regular_user: _User, @@ -4510,7 +4511,7 @@ async def test_set_job_status_unprivileged( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], run_job: Callable[..., Awaitable[str]], regular_user: _User, cluster_name: str, @@ -4532,7 +4533,7 @@ async def test_set_job_materialized( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: 
dict[str, Any], jobs_client: JobsClient, run_job: Callable[..., Awaitable[str]], regular_user: _User, @@ -4565,7 +4566,7 @@ async def test_update_max_run_time( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, run_job: Callable[..., Awaitable[str]], regular_user: _User, @@ -4591,7 +4592,7 @@ async def test_delete_job( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -4616,7 +4617,7 @@ async def test_delete_job_forbidden( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user_factory: UserFactory, regular_user: _User, @@ -4649,7 +4650,7 @@ async def test_delete_already_deleted( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -4683,7 +4684,7 @@ async def test_drop_job( self, api: ApiConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client: JobsClient, regular_user: _User, ) -> None: @@ -4711,7 +4712,7 @@ async def test_drop_job( async def test_create_validation_failure( self, api: ApiConfig, client: aiohttp.ClientSession, regular_user: _User ) -> None: - request_payload: Dict[str, Any] = {} + request_payload: dict[str, Any] = {} async with client.post( api.jobs_base_url, headers=regular_user.headers, json=request_payload ) as response: @@ -4722,7 +4723,7 @@ async def test_create_validation_failure( @pytest.mark.asyncio async def test_resolve_job_by_name( - self, job_submit: Dict[str, Any], jobs_client: JobsClient + self, job_submit: dict[str, Any], jobs_client: JobsClient ) -> None: job_name = f"test-job-name-{random_str()}" job_submit["name"] = job_name @@ -4745,8 +4746,8 @@ async def test_get_job_shared_by_name( jobs_client_factory: Callable[[_User], JobsClient], run_job: Callable[..., Awaitable[str]], share_job: Callable[[_User, _User, Any], Awaitable[None]], - create_job_request_with_name: Callable[[str], Dict[str, Any]], - create_job_request_no_name: Callable[[], Dict[str, Any]], + create_job_request_with_name: Callable[[str], dict[str, Any]], + create_job_request_no_name: Callable[[], dict[str, Any]], ) -> None: job_name = "test-job-name" job_name2 = "test-job-name2" @@ -4782,8 +4783,8 @@ async def test_delete_job_shared_by_name( jobs_client_factory: Callable[[_User], JobsClient], run_job: Callable[..., Awaitable[str]], share_job: Callable[..., Awaitable[None]], - create_job_request_with_name: Callable[[str], Dict[str, Any]], - create_job_request_no_name: Callable[[], Dict[str, Any]], + create_job_request_with_name: Callable[[str], dict[str, Any]], + create_job_request_no_name: Callable[[], dict[str, Any]], ) -> None: job_name = "test-job-name" job_name2 = "test-job-name2" @@ -5265,7 +5266,7 @@ async def test_enforce_runtime( auth_api: AuthApiConfig, config: Config, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client_factory: Callable[[_User], JobsClient], regular_user: _User, ) -> None: @@ -5326,7 +5327,7 @@ async def test_enforce_billing( config: Config, cluster_config: ClusterConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client_factory: Callable[[_User], JobsClient], regular_user_factory: 
Callable[[], Awaitable[_User]], admin_client: AdminClient, @@ -5334,7 +5335,7 @@ async def test_enforce_billing( ) -> None: durations = [30, 25, 20, 15, 10, 5] - test_jobs: List[Tuple[Dict[str, Any], _User, JobsClient, int]] = [] + test_jobs: list[tuple[dict[str, Any], _User, JobsClient, int]] = [] for duration in durations: job_submit["container"]["command"] = f"sleep {duration}s" user = await regular_user_factory() @@ -5357,7 +5358,7 @@ async def test_enforce_billing( # Wait for 7 ticks for jobs to become charged await asyncio.sleep(config.job_policy_enforcer.interval_sec * 7) - user_to_charge: Dict[str, Decimal] = defaultdict(Decimal) + user_to_charge: dict[str, Decimal] = defaultdict(Decimal) for cluster_user in await admin_client.list_cluster_users(cluster_name): user_to_charge[cluster_user.user_name] = cluster_user.balance.spent_credits @@ -5381,7 +5382,7 @@ async def test_enforce_billing_with_org( config: Config, cluster_config: ClusterConfig, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client_factory: Callable[[_User], JobsClient], regular_user_factory: UserFactory, admin_client: AdminClient, @@ -5389,7 +5390,7 @@ async def test_enforce_billing_with_org( ) -> None: durations = [30, 25, 20, 15, 10, 5] - test_jobs: List[Tuple[Dict[str, Any], _User, JobsClient, int]] = [] + test_jobs: list[tuple[dict[str, Any], _User, JobsClient, int]] = [] for duration in durations: job_submit["container"]["command"] = f"sleep {duration}s" user = await regular_user_factory( @@ -5414,7 +5415,7 @@ async def test_enforce_billing_with_org( # Wait for 7 ticks for jobs to become charged await asyncio.sleep(config.job_policy_enforcer.interval_sec * 7) - user_to_charge: Dict[str, Decimal] = defaultdict(Decimal) + user_to_charge: dict[str, Decimal] = defaultdict(Decimal) for cluster_user in await admin_client.list_cluster_users(cluster_name): if cluster_user.org_name == "org": user_to_charge[ @@ -5443,7 +5444,7 @@ async def test_enforce_retention( auth_api: AuthApiConfig, config: Config, client: aiohttp.ClientSession, - job_submit: Dict[str, Any], + job_submit: dict[str, Any], jobs_client_factory: Callable[[_User], JobsClient], regular_user: _User, ) -> None: diff --git a/tests/integration/test_config_client.py b/tests/integration/test_config_client.py index 8f9474496..07125f2a2 100644 --- a/tests/integration/test_config_client.py +++ b/tests/integration/test_config_client.py @@ -1,5 +1,6 @@ +from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from typing import Any, AsyncIterator, Dict, List +from typing import Any import aiohttp import pytest @@ -11,7 +12,7 @@ @pytest.fixture -def cluster_configs_payload() -> List[Dict[str, Any]]: +def cluster_configs_payload() -> list[dict[str, Any]]: return [ { "name": "cluster_name", @@ -49,7 +50,7 @@ def cluster_configs_payload() -> List[Dict[str, Any]]: ] -async def create_config_app(payload: List[Dict[str, Any]]) -> aiohttp.web.Application: +async def create_config_app(payload: list[dict[str, Any]]) -> aiohttp.web.Application: app = aiohttp.web.Application() async def handle(request: aiohttp.web.Request) -> aiohttp.web.Response: @@ -63,7 +64,7 @@ async def handle(request: aiohttp.web.Request) -> aiohttp.web.Response: @asynccontextmanager async def create_config_api( - cluster_configs_payload: List[Dict[str, Any]] + cluster_configs_payload: list[dict[str, Any]] ) -> AsyncIterator[URL]: app = await create_config_app(cluster_configs_payload) runner = ApiRunner(app, port=8082) @@ -75,7 
+76,7 @@ async def create_config_api( class TestConfigClient: @pytest.mark.asyncio async def test_valid_cluster_configs( - self, cluster_configs_payload: List[Dict[str, Any]] + self, cluster_configs_payload: list[dict[str, Any]] ) -> None: async with create_config_api(cluster_configs_payload) as url: async with ConfigClient(base_url=url) as client: @@ -85,7 +86,7 @@ async def test_valid_cluster_configs( @pytest.mark.asyncio async def test_client_skips_invalid_cluster_configs( - self, cluster_configs_payload: List[Dict[str, Any]] + self, cluster_configs_payload: list[dict[str, Any]] ) -> None: cluster_configs_payload.append({}) async with create_config_api(cluster_configs_payload) as url: diff --git a/tests/integration/test_jobs_storage.py b/tests/integration/test_jobs_storage.py index 69390dee6..749848608 100644 --- a/tests/integration/test_jobs_storage.py +++ b/tests/integration/test_jobs_storage.py @@ -1,6 +1,6 @@ from datetime import timedelta from itertools import islice -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional import pytest from sqlalchemy.ext.asyncio import AsyncEngine @@ -488,7 +488,7 @@ async def test_get_all_filter_by_tags(self, storage: JobsStorage) -> None: @pytest.mark.asyncio @pytest.mark.parametrize("statuses", [(), (JobStatus.PENDING, JobStatus.RUNNING)]) async def test_get_all_filter_by_date_range( - self, statuses: Tuple[JobStatus, ...], storage: JobsStorage + self, statuses: tuple[JobStatus, ...], storage: JobsStorage ) -> None: t1 = current_datetime_factory() job1 = self._create_job() @@ -572,7 +572,7 @@ async def test_get_all_filter_by_date_range( job_ids = [job.id for job in await storage.get_all_jobs(job_filter)] assert job_ids == [] - async def prepare_filtering_test(self, storage: JobsStorage) -> List[JobRecord]: + async def prepare_filtering_test(self, storage: JobsStorage) -> list[JobRecord]: jobs = [ # no name: self._create_pending_job(owner="user1", job_name=None), @@ -706,12 +706,12 @@ async def test_get_all_filter_by_multiple_owners( ) async def test_get_all_with_filters( self, - owners: Tuple[str, ...], + owners: tuple[str, ...], name: Optional[str], - statuses: Tuple[JobStatus, ...], + statuses: tuple[JobStatus, ...], storage: JobsStorage, ) -> None: - def sort_jobs_as_primitives(array: List[JobRecord]) -> List[Dict[str, Any]]: + def sort_jobs_as_primitives(array: list[JobRecord]) -> list[dict[str, Any]]: return sorted( (job.to_primitive() for job in array), key=lambda job: job["id"] ) @@ -745,7 +745,7 @@ def sort_jobs_as_primitives(array: List[JobRecord]) -> List[Dict[str, Any]]: async def test_get_all_filter_by_name_with_no_owner( self, name: Optional[str], - statuses: Tuple[JobStatus, ...], + statuses: tuple[JobStatus, ...], storage: JobsStorage, ) -> None: jobs = await self.prepare_filtering_test(storage) @@ -917,7 +917,7 @@ async def test_get_all_filter_by_hostname_and_status( async def prepare_filtering_test_different_clusters( self, storage: JobsStorage - ) -> List[JobRecord]: + ) -> list[JobRecord]: jobs = [ self._create_running_job(owner="user1", cluster_name="test-cluster"), self._create_succeeded_job( @@ -958,7 +958,7 @@ async def test_get_all_filter_by_cluster(self, storage: JobsStorage) -> None: async def prepare_filtering_test_different_orgs( self, storage: JobsStorage - ) -> List[JobRecord]: + ) -> list[JobRecord]: jobs = [ self._create_running_job( owner="user1", cluster_name="test-cluster", org_name=None diff --git a/tests/integration/test_kube_orchestrator.py 
b/tests/integration/test_kube_orchestrator.py index f991e9529..d1da7da56 100644 --- a/tests/integration/test_kube_orchestrator.py +++ b/tests/integration/test_kube_orchestrator.py @@ -3,20 +3,11 @@ import shlex import time import uuid -from contextlib import asynccontextmanager +from collections.abc import AsyncIterator, Awaitable, Callable, Iterator +from contextlib import AbstractAsyncContextManager, asynccontextmanager from dataclasses import replace from pathlib import PurePath -from typing import ( - Any, - AsyncContextManager, - AsyncIterator, - Awaitable, - Callable, - Dict, - Iterator, - List, - Optional, -) +from typing import Any, Optional from unittest import mock import aiohttp @@ -883,7 +874,7 @@ async def test_list_services(self, kube_client: KubeClient) -> None: labels1 = {"label": f"value-{uuid.uuid4()}"} labels2 = {"label": f"value-{uuid.uuid4()}"} - def _gen_for_labels(labels: Dict[str, str]) -> List[Service]: + def _gen_for_labels(labels: dict[str, str]) -> list[Service]: return [ Service(name=f"job-{uuid.uuid4()}", target_port=8080, labels=labels) for _ in range(5) @@ -2284,7 +2275,7 @@ async def kube_orchestrator( @pytest.fixture async def start_job( self, kube_client: MyKubeClient - ) -> Callable[..., AsyncContextManager[MyJob]]: + ) -> Callable[..., AbstractAsyncContextManager[MyJob]]: @asynccontextmanager async def _create( kube_orchestrator: KubeOrchestrator, @@ -2330,7 +2321,7 @@ async def _create( async def test_unschedulable_job( self, kube_orchestrator: KubeOrchestrator, - start_job: Callable[..., AsyncContextManager[MyJob]], + start_job: Callable[..., AbstractAsyncContextManager[MyJob]], ) -> None: with pytest.raises(JobError, match="Job will not fit into cluster"): async with start_job(kube_orchestrator, cpu=100, memory_mb=32): @@ -2341,7 +2332,7 @@ async def test_cpu_job( self, kube_client: MyKubeClient, kube_orchestrator: KubeOrchestrator, - start_job: Callable[..., AsyncContextManager[MyJob]], + start_job: Callable[..., AbstractAsyncContextManager[MyJob]], ) -> None: async with start_job(kube_orchestrator, cpu=0.1, memory_mb=32) as job: await kube_client.wait_pod_scheduled(job.id, "cpu-small") @@ -2363,7 +2354,7 @@ async def test_cpu_job_on_tpu_node( self, kube_client: MyKubeClient, kube_orchestrator: KubeOrchestrator, - start_job: Callable[..., AsyncContextManager[MyJob]], + start_job: Callable[..., AbstractAsyncContextManager[MyJob]], ) -> None: async with start_job(kube_orchestrator, cpu=3, memory_mb=32) as job: await kube_client.wait_pod_scheduled(job.id, "cpu-large-tpu") @@ -2389,7 +2380,7 @@ async def test_cpu_job_not_scheduled_on_gpu_node( self, kube_client: MyKubeClient, kube_orchestrator: KubeOrchestrator, - start_job: Callable[..., AsyncContextManager[MyJob]], + start_job: Callable[..., AbstractAsyncContextManager[MyJob]], ) -> None: with pytest.raises(JobError, match="Job will not fit into cluster"): async with start_job(kube_orchestrator, cpu=7, memory_mb=32): @@ -2400,7 +2391,7 @@ async def test_gpu_job( self, kube_client: MyKubeClient, kube_orchestrator: KubeOrchestrator, - start_job: Callable[..., AsyncContextManager[MyJob]], + start_job: Callable[..., AbstractAsyncContextManager[MyJob]], ) -> None: async with start_job( kube_orchestrator, @@ -2428,7 +2419,7 @@ async def test_scheduled_job_on_not_preemptible_node( self, kube_client: MyKubeClient, kube_orchestrator: KubeOrchestrator, - start_job: Callable[..., AsyncContextManager[MyJob]], + start_job: Callable[..., AbstractAsyncContextManager[MyJob]], ) -> None: async with start_job( 
kube_orchestrator, @@ -2458,7 +2449,7 @@ async def test_preemptible_job_on_preemptible_node( self, kube_client: MyKubeClient, kube_orchestrator: KubeOrchestrator, - start_job: Callable[..., AsyncContextManager[MyJob]], + start_job: Callable[..., AbstractAsyncContextManager[MyJob]], ) -> None: async with start_job( kube_orchestrator, @@ -3013,8 +3004,8 @@ async def create_network_policy( self, kube_client: KubeClient, delete_network_policy_later: Callable[[str], Awaitable[None]], - ) -> Callable[[str], Awaitable[Dict[str, Any]]]: - async def _f(job_id: str) -> Dict[str, Any]: + ) -> Callable[[str], Awaitable[dict[str, Any]]]: + async def _f(job_id: str) -> dict[str, Any]: np_name = f"networkpolicy-{uuid.uuid4().hex[:6]}" labels = {"platform.neuromation.io/job": job_id} @@ -3030,7 +3021,7 @@ async def _f(job_id: str) -> Dict[str, Any]: async def test_get_all_job_resources_links_job_network_policy( self, kube_client: KubeClient, - create_network_policy: Callable[[str], Awaitable[Dict[str, Any]]], + create_network_policy: Callable[[str], Awaitable[dict[str, Any]]], ) -> None: job_id = f"job-{uuid.uuid4()}" payload = await create_network_policy(job_id) @@ -3043,7 +3034,7 @@ async def test_get_all_job_resources_links_job_network_policy( async def test_delete_resource_by_link_network_policy( self, kube_client: KubeClient, - create_network_policy: Callable[[str], Awaitable[Dict[str, Any]]], + create_network_policy: Callable[[str], Awaitable[dict[str, Any]]], ) -> None: job_id = f"job-{uuid.uuid4()}" payload = await create_network_policy(job_id) @@ -3059,7 +3050,7 @@ async def test_delete_resource_by_link_network_policy( @pytest.fixture async def mock_kubernetes_server() -> AsyncIterator[ApiConfig]: async def _get_pod(request: web.Request) -> web.Response: - payload: Dict[str, Any] = { + payload: dict[str, Any] = { "kind": "Pod", "metadata": { "name": "testname", diff --git a/tests/integration/test_notifications.py b/tests/integration/test_notifications.py index 2762ad191..43e08e9a8 100644 --- a/tests/integration/test_notifications.py +++ b/tests/integration/test_notifications.py @@ -1,5 +1,6 @@ +from collections.abc import AsyncIterator, Awaitable, Callable from decimal import Decimal -from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Set +from typing import Any from unittest import mock import aiohttp.web @@ -20,7 +21,7 @@ async def test_not_sent_has_credits( self, api: ApiConfig, client: aiohttp.ClientSession, - job_request_factory: Callable[[], Dict[str, Any]], + job_request_factory: Callable[[], dict[str, Any]], jobs_client: Callable[[], Any], regular_user_factory: UserFactory, mock_notifications_server: NotificationsServer, @@ -44,7 +45,7 @@ async def test_sent_if_no_credits( self, api: ApiConfig, client: aiohttp.ClientSession, - job_request_factory: Callable[[], Dict[str, Any]], + job_request_factory: Callable[[], dict[str, Any]], jobs_client: Callable[[], Any], regular_user_factory: UserFactory, mock_notifications_server: NotificationsServer, @@ -75,12 +76,12 @@ async def run_job( api: ApiConfig, client: aiohttp.ClientSession, jobs_client_factory: Callable[[_User], JobsClient], - ) -> AsyncIterator[Callable[[_User, Dict[str, Any], bool, bool], Awaitable[str]]]: + ) -> AsyncIterator[Callable[[_User, dict[str, Any], bool, bool], Awaitable[str]]]: cleanup_pairs = [] async def _impl( user: _User, - job_request: Dict[str, Any], + job_request: dict[str, Any], wait_for_start: bool = True, do_kill: bool = False, ) -> str: @@ -113,7 +114,7 @@ async def 
test_not_sent_job_creating_failed( self, api: ApiConfig, client: aiohttp.ClientSession, - job_request_factory: Callable[[str], Dict[str, Any]], + job_request_factory: Callable[[str], dict[str, Any]], regular_user_factory: UserFactory, mock_notifications_server: NotificationsServer, ) -> None: @@ -133,7 +134,7 @@ async def test_not_sent_job_creating_failed( async def test_succeeded_job_workflow( self, api: ApiConfig, - job_request_factory: Callable[[], Dict[str, Any]], + job_request_factory: Callable[[], dict[str, Any]], jobs_client_factory: Callable[[_User], JobsClient], regular_user_factory: UserFactory, mock_notifications_server: NotificationsServer, @@ -148,7 +149,7 @@ async def test_succeeded_job_workflow( await jobs_client.delete_job(job_id) await api.runner.close() - states: Set[str] = set() + states: set[str] = set() for (slug, payload) in mock_notifications_server.requests: if slug != "job-transition": raise AssertionError(f"Unexpected Notification: {slug} : {payload}") @@ -174,7 +175,7 @@ async def test_succeeded_job_workflow( async def test_failed_job_workflow( self, api: ApiConfig, - job_request_factory: Callable[[], Dict[str, Any]], + job_request_factory: Callable[[], dict[str, Any]], jobs_client_factory: Callable[[_User], JobsClient], regular_user_factory: UserFactory, mock_notifications_server: NotificationsServer, @@ -189,7 +190,7 @@ async def test_failed_job_workflow( await jobs_client.long_polling_by_job_id(job_id, "failed") await api.runner.close() - states: Set[str] = set() + states: set[str] = set() for (slug, payload) in mock_notifications_server.requests: if slug != "job-transition": raise AssertionError(f"Unexpected Notification: {slug} : {payload}") @@ -217,7 +218,7 @@ async def test_sent_if_credits_less_then_threshold( config: Config, api: ApiConfig, client: aiohttp.ClientSession, - job_request_factory: Callable[[], Dict[str, Any]], + job_request_factory: Callable[[], dict[str, Any]], jobs_client_factory: Callable[[_User], JobsClient], regular_user_factory: UserFactory, mock_notifications_server: NotificationsServer, diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 1a55b3eaa..e6d5fcc10 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,23 +1,11 @@ import asyncio from collections import defaultdict +from collections.abc import AsyncIterator, Callable, Iterator, Sequence from datetime import datetime, timedelta, timezone from decimal import Decimal from functools import partial from pathlib import Path -from typing import ( - Any, - AsyncIterator, - Awaitable, - Callable, - Dict, - Iterator, - List, - Optional, - Sequence, - Tuple, - Union, - cast, -) +from typing import Any, Awaitable, Optional, Union, cast import pytest from aiohttp import ClientResponseError @@ -84,9 +72,9 @@ def __init__(self, config: ClusterConfig) -> None: self._mock_status_to_return = JobStatus.PENDING self._mock_reason_to_return: Optional[str] = JobStatusReason.CONTAINER_CREATING self._mock_exit_code_to_return: Optional[int] = None - self._mock_statuses: Dict[str, JobStatus] = {} - self._mock_reasons: Dict[str, Optional[str]] = {} - self._mock_exit_codes: Dict[str, Optional[int]] = {} + self._mock_statuses: dict[str, JobStatus] = {} + self._mock_reasons: dict[str, Optional[str]] = {} + self._mock_exit_codes: dict[str, Optional[int]] = {} self.raise_on_get_job_status = False self.raise_on_start_job_status = False self.get_job_status_exc_factory = self._create_get_job_status_exc @@ -95,7 +83,7 @@ def __init__(self, config: ClusterConfig) -> None: 
self.current_datetime_factory: Callable[[], datetime] = partial( datetime.now, timezone.utc ) - self._successfully_deleted_jobs: List[Job] = [] + self._successfully_deleted_jobs: list[Job] = [] @property def config(self) -> OrchestratorConfig: @@ -167,15 +155,15 @@ def update_exit_code_to_return_single( ) -> None: self._mock_exit_codes[job_id] = new_exit_code - def get_successfully_deleted_jobs(self) -> List[Job]: + def get_successfully_deleted_jobs(self) -> list[Job]: return self._successfully_deleted_jobs async def get_missing_secrets( - self, user_name: str, secret_names: List[str] - ) -> List[str]: + self, user_name: str, secret_names: list[str] + ) -> list[str]: pass - async def get_missing_disks(self, disks: List[Disk]) -> List[Disk]: + async def get_missing_disks(self, disks: list[Disk]) -> list[Disk]: pass @@ -192,7 +180,7 @@ async def set_job(self, job: JobRecord) -> None: class MockNotificationsClient(NotificationsClient): def __init__(self) -> None: - self._sent_notifications: List[Notification] = [] + self._sent_notifications: list[Notification] = [] pass async def notify(self, notification: Notification) -> None: @@ -205,7 +193,7 @@ async def close(self) -> None: pass @property - def sent_notifications(self) -> List[Notification]: + def sent_notifications(self) -> list[Notification]: return self._sent_notifications @@ -214,18 +202,18 @@ def __init__(self) -> None: self.user_to_return = AuthUser( name="testuser", ) - self._grants: List[Tuple[str, Sequence[Permission]]] = [] - self._revokes: List[Tuple[str, Sequence[str]]] = [] + self._grants: list[tuple[str, Sequence[Permission]]] = [] + self._revokes: list[tuple[str, Sequence[str]]] = [] async def get_user(self, name: str, token: Optional[str] = None) -> AuthUser: return self.user_to_return @property - def grants(self) -> List[Tuple[str, Sequence[Permission]]]: + def grants(self) -> list[tuple[str, Sequence[Permission]]]: return self._grants @property - def revokes(self) -> List[Tuple[str, Sequence[str]]]: + def revokes(self) -> list[tuple[str, Sequence[str]]]: return self._revokes async def grant_user_permissions( @@ -249,16 +237,16 @@ async def get_user_token( class MockAdminClient(AdminClientDummy): def __init__(self) -> None: - self.users: Dict[str, User] = {} - self.cluster_users: Dict[str, List[ClusterUser]] = defaultdict(list) - self.org_clusters: Dict[str, List[OrgCluster]] = defaultdict(list) - self.spending_log: List[ - Tuple[str, Optional[str], str, Decimal, Optional[str]] + self.users: dict[str, User] = {} + self.cluster_users: dict[str, list[ClusterUser]] = defaultdict(list) + self.org_clusters: dict[str, list[OrgCluster]] = defaultdict(list) + self.spending_log: list[ + tuple[str, Optional[str], str, Decimal, Optional[str]] ] = [] - self.debts_log: List[Tuple[str, str, Decimal, str]] = [] + self.debts_log: list[tuple[str, str, Decimal, str]] = [] self.raise_404: bool = False - async def get_user_with_clusters(self, name: str) -> Tuple[User, List[ClusterUser]]: + async def get_user_with_clusters(self, name: str) -> tuple[User, list[ClusterUser]]: if name not in self.users: raise ClientResponseError(None, (), status=404) # type: ignore return self.users[name], self.cluster_users[name] @@ -331,10 +319,10 @@ def __init__(self, jobs_service: JobsService, jobs_storage: JobsStorage): self._jobs_service = jobs_service self._jobs_storage = jobs_storage - async def get_unfinished_jobs(self) -> List[JobRecord]: + async def get_unfinished_jobs(self) -> list[JobRecord]: return await self._jobs_storage.get_unfinished_jobs() - 
async def get_jobs_for_deletion(self, *, delay: timedelta) -> List[JobRecord]: + async def get_jobs_for_deletion(self, *, delay: timedelta) -> list[JobRecord]: return await self._jobs_storage.get_jobs_for_deletion(delay=delay) async def push_status(self, job_id: str, status: JobStatusItem) -> None: @@ -348,7 +336,7 @@ async def set_materialized(self, job_id: str, materialized: bool) -> None: @pytest.fixture def job_request_factory() -> Callable[[], JobRequest]: def factory(with_gpu: bool = False) -> JobRequest: - cont_kwargs: Dict[str, Any] = {"cpu": 1, "memory_mb": 128} + cont_kwargs: dict[str, Any] = {"cpu": 1, "memory_mb": 128} if with_gpu: cont_kwargs["gpu"] = 1 cont_kwargs["gpu_model_id"] = "nvidia-tesla-k80" @@ -537,7 +525,7 @@ def event_loop() -> Iterator[asyncio.AbstractEventLoop]: UserFactory = Callable[ - [str, List[Union[Tuple[str, Balance, Quota], Tuple[str, str, Balance, Quota]]]], + [str, list[Union[tuple[str, Balance, Quota], tuple[str, str, Balance, Quota]]]], Awaitable[AuthUser], ] @@ -548,18 +536,18 @@ def user_factory( ) -> UserFactory: async def _factory( name: str, - clusters: List[ - Union[Tuple[str, Balance, Quota], Tuple[str, str, Balance, Quota]] + clusters: list[ + Union[tuple[str, Balance, Quota], tuple[str, str, Balance, Quota]] ], ) -> AuthUser: mock_admin_client.users[name] = User(name=name, email=f"{name}@domain.com") for entry in clusters: org_name: Optional[str] = None if len(entry) == 3: - cluster, balance, quota = cast(Tuple[str, Balance, Quota], entry) + cluster, balance, quota = cast(tuple[str, Balance, Quota], entry) else: cluster, org_name, balance, quota = cast( - Tuple[str, str, Balance, Quota], entry + tuple[str, str, Balance, Quota], entry ) mock_admin_client.cluster_users[name].append( ClusterUser( @@ -576,14 +564,14 @@ async def _factory( return _factory -OrgFactory = Callable[[str, List[Tuple[str, Balance, Quota]]], Awaitable[str]] +OrgFactory = Callable[[str, list[tuple[str, Balance, Quota]]], Awaitable[str]] @pytest.fixture def org_factory( mock_admin_client: MockAdminClient, ) -> OrgFactory: - async def _factory(name: str, clusters: List[Tuple[str, Balance, Quota]]) -> str: + async def _factory(name: str, clusters: list[tuple[str, Balance, Quota]]) -> str: mock_admin_client.users[name] = User(name=name, email=f"{name}@domain.com") for cluster, balance, quota in clusters: mock_admin_client.org_clusters[name].append( diff --git a/tests/unit/test_billing_log_service.py b/tests/unit/test_billing_log_service.py index 3b5b07832..36a0ba268 100644 --- a/tests/unit/test_billing_log_service.py +++ b/tests/unit/test_billing_log_service.py @@ -1,7 +1,8 @@ import asyncio +from collections.abc import AsyncIterator, Callable, Mapping from datetime import datetime, timezone from decimal import Decimal -from typing import Any, AsyncIterator, Callable, Mapping +from typing import Any import pytest from neuro_admin_client import AdminClient, Balance, Quota diff --git a/tests/unit/test_cluster_config_factory.py b/tests/unit/test_cluster_config_factory.py index 266149028..3dec8d158 100644 --- a/tests/unit/test_cluster_config_factory.py +++ b/tests/unit/test_cluster_config_factory.py @@ -1,5 +1,6 @@ +from collections.abc import Sequence from decimal import Decimal -from typing import Any, Dict, List, Sequence +from typing import Any import pytest from yarl import URL @@ -9,7 +10,7 @@ @pytest.fixture -def host_storage_payload() -> Dict[str, Any]: +def host_storage_payload() -> dict[str, Any]: return { "storage": { "host": {"mount_path": "/host/mount/path"}, @@ 
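The hunks above all apply the same mechanical rewrite: under PEP 585 (Python 3.9+) the built-in collections accept subscripts directly, so the typing.List/Dict/Tuple aliases can be dropped from annotations. A minimal before/after sketch using the clusters_payload fixture shape from this file (body abbreviated, for illustration only):

from typing import Any

# Python 3.8 spelling:
#   from typing import Any, Dict, List
#   def clusters_payload(...) -> List[Dict[str, Any]]: ...

# Python 3.9 spelling -- no typing imports needed for the containers:
def clusters_payload(nfs_storage_payload: dict[str, Any]) -> list[dict[str, Any]]:
    return [{"name": "cluster_name", "storage": nfs_storage_payload["storage"]}]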
-19,7 +20,7 @@ def host_storage_payload() -> Dict[str, Any]: @pytest.fixture -def nfs_storage_payload() -> Dict[str, Any]: +def nfs_storage_payload() -> dict[str, Any]: return { "storage": { "nfs": {"server": "127.0.0.1", "export_path": "/nfs/export/path"}, @@ -29,7 +30,7 @@ def nfs_storage_payload() -> Dict[str, Any]: @pytest.fixture -def pvc_storage_payload() -> Dict[str, Any]: +def pvc_storage_payload() -> dict[str, Any]: return { "storage": { "pvc": {"name": "platform-storage"}, @@ -39,7 +40,7 @@ def pvc_storage_payload() -> Dict[str, Any]: @pytest.fixture -def clusters_payload(nfs_storage_payload: Dict[str, Any]) -> List[Dict[str, Any]]: +def clusters_payload(nfs_storage_payload: dict[str, Any]) -> list[dict[str, Any]]: return [ { "name": "cluster_name", @@ -193,7 +194,7 @@ def users_url() -> URL: class TestClusterConfigFactory: def test_valid_cluster_config( - self, clusters_payload: Sequence[Dict[str, Any]] + self, clusters_payload: Sequence[dict[str, Any]] ) -> None: storage_payload = clusters_payload[0]["storage"] registry_payload = clusters_payload[0]["registry"] @@ -297,7 +298,7 @@ def test_valid_cluster_config( assert orchestrator.tpu_ipv4_cidr_block == "1.1.1.1/32" def test_orchestrator_resource_presets( - self, clusters_payload: Sequence[Dict[str, Any]] + self, clusters_payload: Sequence[dict[str, Any]] ) -> None: factory = ClusterConfigFactory() clusters_payload[0]["orchestrator"]["resource_presets"] = [ @@ -342,7 +343,7 @@ def test_orchestrator_resource_presets( ] def test_orchestrator_job_schedule_settings_default( - self, clusters_payload: Sequence[Dict[str, Any]] + self, clusters_payload: Sequence[dict[str, Any]] ) -> None: orchestrator = clusters_payload[0]["orchestrator"] del orchestrator["job_schedule_timeout_s"] @@ -355,7 +356,7 @@ def test_orchestrator_job_schedule_settings_default( assert clusters[0].orchestrator.job_schedule_scaleup_timeout == 900 def test_factory_skips_invalid_cluster_configs( - self, clusters_payload: List[Dict[str, Any]] + self, clusters_payload: list[dict[str, Any]] ) -> None: clusters_payload.append({}) factory = ClusterConfigFactory() diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index ececd67f2..f0273eaca 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -1,7 +1,6 @@ from datetime import timedelta from decimal import Decimal from pathlib import PurePath -from typing import Dict import pytest from yarl import URL @@ -234,7 +233,7 @@ def test_create_secret_volume(self, registry_config: RegistryConfig) -> None: class TestEnvironConfigFactory: def test_create_key_error(self) -> None: - environ: Dict[str, str] = {} + environ: dict[str, str] = {} with pytest.raises(KeyError): EnvironConfigFactory(environ=environ).create() diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py index 2086b537a..6ab14509e 100644 --- a/tests/unit/test_job.py +++ b/tests/unit/test_job.py @@ -1,7 +1,8 @@ import dataclasses +from collections.abc import Callable from datetime import datetime, timedelta, timezone from pathlib import PurePath -from typing import Any, Callable, Dict +from typing import Any from unittest import mock import pytest @@ -436,7 +437,7 @@ def test_from_payload_build_with_tty(self) -> None: @pytest.fixture -def job_request_payload() -> Dict[str, Any]: +def job_request_payload() -> dict[str, Any]: return { "job_id": "testjob", "description": "Description of the testjob", @@ -459,7 +460,7 @@ def job_request_payload() -> Dict[str, Any]: @pytest.fixture -def job_payload(job_request_payload: Any) -> 
Dict[str, Any]: +def job_payload(job_request_payload: Any) -> dict[str, Any]: finished_at_str = datetime.now(timezone.utc).isoformat() return { "id": "testjob", @@ -472,7 +473,7 @@ def job_payload(job_request_payload: Any) -> Dict[str, Any]: @pytest.fixture -def job_request_payload_with_shm(job_request_payload: Dict[str, Any]) -> Dict[str, Any]: +def job_request_payload_with_shm(job_request_payload: dict[str, Any]) -> dict[str, Any]: data = job_request_payload data["container"]["resources"]["shm"] = True return data @@ -973,7 +974,7 @@ def test_to_primitive_with_org_name( assert primitive["org_name"] == "10250zxvgew" def test_from_primitive( - self, mock_orchestrator: MockOrchestrator, job_request_payload: Dict[str, Any] + self, mock_orchestrator: MockOrchestrator, job_request_payload: dict[str, Any] ) -> None: payload = { "id": "testjob", @@ -999,7 +1000,7 @@ def test_from_primitive( assert job.restart_policy == JobRestartPolicy.NEVER def test_from_primitive_check_name( - self, mock_orchestrator: MockOrchestrator, job_request_payload: Dict[str, Any] + self, mock_orchestrator: MockOrchestrator, job_request_payload: dict[str, Any] ) -> None: payload = { "id": "testjob", @@ -1015,7 +1016,7 @@ def test_from_primitive_check_name( assert job.name == "test-job-name" def test_from_primitive_with_preset_name( - self, mock_orchestrator: MockOrchestrator, job_request_payload: Dict[str, Any] + self, mock_orchestrator: MockOrchestrator, job_request_payload: dict[str, Any] ) -> None: payload = { "id": "testjob", @@ -1030,7 +1031,7 @@ def test_from_primitive_with_preset_name( assert job.preset_name == "cpu-small" def test_from_primitive_with_tags( - self, mock_orchestrator: MockOrchestrator, job_request_payload: Dict[str, Any] + self, mock_orchestrator: MockOrchestrator, job_request_payload: dict[str, Any] ) -> None: tags = ["tag1", "tag2"] payload = { @@ -1047,7 +1048,7 @@ def test_from_primitive_with_tags( assert job.tags == tags def test_from_primitive_with_statuses( - self, mock_orchestrator: MockOrchestrator, job_request_payload: Dict[str, Any] + self, mock_orchestrator: MockOrchestrator, job_request_payload: dict[str, Any] ) -> None: finished_at_str = datetime.now(timezone.utc).isoformat() payload = { @@ -1071,7 +1072,7 @@ def test_from_primitive_with_statuses( assert job.preemptible_node def test_from_primitive_with_cluster_name( - self, mock_orchestrator: MockOrchestrator, job_request_payload: Dict[str, Any] + self, mock_orchestrator: MockOrchestrator, job_request_payload: dict[str, Any] ) -> None: payload = { "id": "testjob", @@ -1095,7 +1096,7 @@ def test_from_primitive_with_cluster_name( assert not job.preemptible_node def test_from_primitive_with_entrypoint_without_command( - self, mock_orchestrator: MockOrchestrator, job_request_payload: Dict[str, Any] + self, mock_orchestrator: MockOrchestrator, job_request_payload: dict[str, Any] ) -> None: job_request_payload["container"]["entrypoint"] = "/script.sh" job_request_payload["container"].pop("command", None) @@ -1113,7 +1114,7 @@ def test_from_primitive_with_entrypoint_without_command( assert job.request.container.entrypoint == "/script.sh" def test_from_primitive_without_entrypoint_with_command( - self, mock_orchestrator: MockOrchestrator, job_request_payload: Dict[str, Any] + self, mock_orchestrator: MockOrchestrator, job_request_payload: dict[str, Any] ) -> None: job_request_payload["container"].pop("entrypoint", None) job_request_payload["container"]["command"] = "arg1 arg2 arg3" @@ -1131,7 +1132,7 @@ def 
test_from_primitive_without_entrypoint_with_command( assert job.request.container.entrypoint is None def test_from_primitive_without_entrypoint_without_command( - self, mock_orchestrator: MockOrchestrator, job_request_payload: Dict[str, Any] + self, mock_orchestrator: MockOrchestrator, job_request_payload: dict[str, Any] ) -> None: job_request_payload["container"].pop("entrypoint", None) job_request_payload["container"].pop("command", None) @@ -1149,7 +1150,7 @@ def test_from_primitive_without_entrypoint_without_command( assert job.request.container.entrypoint is None def test_from_primitive_with_entrypoint_with_command( - self, mock_orchestrator: MockOrchestrator, job_request_payload: Dict[str, Any] + self, mock_orchestrator: MockOrchestrator, job_request_payload: dict[str, Any] ) -> None: job_request_payload["container"]["entrypoint"] = "/script.sh" job_request_payload["container"]["command"] = "arg1 arg2 arg3" @@ -1167,7 +1168,7 @@ def test_from_primitive_with_entrypoint_with_command( assert job.request.container.entrypoint == "/script.sh" def test_from_primitive_with_max_run_time_minutes( - self, mock_orchestrator: MockOrchestrator, job_request_payload: Dict[str, Any] + self, mock_orchestrator: MockOrchestrator, job_request_payload: dict[str, Any] ) -> None: payload = { "id": "testjob", @@ -1183,7 +1184,7 @@ def test_from_primitive_with_max_run_time_minutes( assert job.max_run_time_minutes == 100 def test_from_primitive_with_max_run_time_minutes_none( - self, mock_orchestrator: MockOrchestrator, job_request_payload: Dict[str, Any] + self, mock_orchestrator: MockOrchestrator, job_request_payload: dict[str, Any] ) -> None: payload = { "id": "testjob", @@ -1199,7 +1200,7 @@ def test_from_primitive_with_max_run_time_minutes_none( assert job.max_run_time_minutes is None def test_from_primitive_with_org_name( - self, mock_orchestrator: MockOrchestrator, job_request_payload: Dict[str, Any] + self, mock_orchestrator: MockOrchestrator, job_request_payload: dict[str, Any] ) -> None: payload = { "id": "testjob", @@ -1259,7 +1260,7 @@ def test_to_uri_no_owner( assert job.to_uri() == URL(f"job://test-cluster/{job.id}") def test_to_and_from_primitive( - self, mock_orchestrator: MockOrchestrator, job_request_payload: Dict[str, Any] + self, mock_orchestrator: MockOrchestrator, job_request_payload: dict[str, Any] ) -> None: finished_at_str = datetime.now(timezone.utc).isoformat() current_status_item = { @@ -1292,7 +1293,7 @@ def test_to_and_from_primitive( class TestJobRequest: - def test_to_primitive(self, job_request_payload: Dict[str, Any]) -> None: + def test_to_primitive(self, job_request_payload: dict[str, Any]) -> None: container = Container( image="testimage", env={"testvar": "testval"}, @@ -1312,7 +1313,7 @@ def test_to_primitive(self, job_request_payload: Dict[str, Any]) -> None: assert request.to_primitive() == job_request_payload def test_to_primitive_with_entrypoint( - self, job_request_payload: Dict[str, Any] + self, job_request_payload: dict[str, Any] ) -> None: job_request_payload["container"]["entrypoint"] = "/bin/ls" container = Container( @@ -1335,7 +1336,7 @@ def test_to_primitive_with_entrypoint( assert request.to_primitive() == job_request_payload def test_to_primitive_with_working_dir( - self, job_request_payload: Dict[str, Any] + self, job_request_payload: dict[str, Any] ) -> None: job_request_payload["container"]["working_dir"] = "/working/dir" container = Container( @@ -1357,7 +1358,7 @@ def test_to_primitive_with_working_dir( ) assert request.to_primitive() == 
job_request_payload - def test_to_primitive_with_tty(self, job_request_payload: Dict[str, Any]) -> None: + def test_to_primitive_with_tty(self, job_request_payload: dict[str, Any]) -> None: job_request_payload["container"]["tty"] = True container = Container( @@ -1379,7 +1380,7 @@ def test_to_primitive_with_tty(self, job_request_payload: Dict[str, Any]) -> Non ) assert request.to_primitive() == job_request_payload - def test_from_primitive(self, job_request_payload: Dict[str, Any]) -> None: + def test_from_primitive(self, job_request_payload: dict[str, Any]) -> None: request = JobRequest.from_primitive(job_request_payload) assert request.job_id == "testjob" assert request.description == "Description of the testjob" @@ -1397,7 +1398,7 @@ def test_from_primitive(self, job_request_payload: Dict[str, Any]) -> None: assert request.container == expected_container def test_from_primitive_with_working_dir( - self, job_request_payload: Dict[str, Any] + self, job_request_payload: dict[str, Any] ) -> None: job_request_payload["container"]["working_dir"] = "/working/dir" request = JobRequest.from_primitive(job_request_payload) @@ -1418,7 +1419,7 @@ def test_from_primitive_with_working_dir( assert request.container == expected_container def test_from_primitive_with_shm( - self, job_request_payload_with_shm: Dict[str, Any] + self, job_request_payload_with_shm: dict[str, Any] ) -> None: request = JobRequest.from_primitive(job_request_payload_with_shm) assert request.job_id == "testjob" @@ -1436,12 +1437,12 @@ def test_from_primitive_with_shm( ) assert request.container == expected_container - def test_to_and_from_primitive(self, job_request_payload: Dict[str, Any]) -> None: + def test_to_and_from_primitive(self, job_request_payload: dict[str, Any]) -> None: actual = JobRequest.to_primitive(JobRequest.from_primitive(job_request_payload)) assert actual == job_request_payload def test_to_and_from_primitive_with_shm( - self, job_request_payload_with_shm: Dict[str, Any] + self, job_request_payload_with_shm: dict[str, Any] ) -> None: actual = JobRequest.to_primitive( JobRequest.from_primitive(job_request_payload_with_shm) @@ -1449,7 +1450,7 @@ def test_to_and_from_primitive_with_shm( assert actual == job_request_payload_with_shm def test_to_and_from_primitive_with_tpu( - self, job_request_payload: Dict[str, Any] + self, job_request_payload: dict[str, Any] ) -> None: job_request_payload["container"]["resources"]["tpu"] = { "type": "v2-8", @@ -1459,7 +1460,7 @@ def test_to_and_from_primitive_with_tpu( assert actual == job_request_payload def test_to_and_from_primitive_with_secret_env( - self, job_request_payload: Dict[str, Any] + self, job_request_payload: dict[str, Any] ) -> None: job_request_payload["container"]["secret_env"] = { "ENV_SECRET1": "secret://clustername/username/key1", @@ -1469,7 +1470,7 @@ def test_to_and_from_primitive_with_secret_env( assert actual == job_request_payload def test_to_and_from_primitive_with_secret_volumes( - self, job_request_payload: Dict[str, Any] + self, job_request_payload: dict[str, Any] ) -> None: job_request_payload["container"]["secret_volumes"] = [ { @@ -1485,7 +1486,7 @@ def test_to_and_from_primitive_with_secret_volumes( assert actual == job_request_payload def test_to_and_from_primitive_with_disk_volumes( - self, job_request_payload: Dict[str, Any] + self, job_request_payload: dict[str, Any] ) -> None: job_request_payload["container"]["disk_volumes"] = [ { diff --git a/tests/unit/test_job_policy_enforcer.py b/tests/unit/test_job_policy_enforcer.py index 
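The test_job.py changes above are annotation-only; the serialization round-trip they assert is unchanged. A condensed sketch of that property, with the JobRequest import path assumed from the repo layout:

from typing import Any

from platform_api.orchestrator.job_request import JobRequest  # assumed import path

def test_to_and_from_primitive(job_request_payload: dict[str, Any]) -> None:
    # Parsing a payload and serializing it back should be the identity.
    actual = JobRequest.to_primitive(JobRequest.from_primitive(job_request_payload))
    assert actual == job_request_payload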
6ccb13a8f..0117f43bb 100644 --- a/tests/unit/test_job_policy_enforcer.py +++ b/tests/unit/test_job_policy_enforcer.py @@ -1,19 +1,11 @@ import asyncio import datetime import logging -from contextlib import asynccontextmanager +from collections.abc import AsyncIterator, Awaitable, Callable, Iterable +from contextlib import AbstractAsyncContextManager, asynccontextmanager from dataclasses import replace from decimal import Decimal -from typing import ( - Any, - AsyncContextManager, - AsyncIterator, - Awaitable, - Callable, - Iterable, - List, - Optional, -) +from typing import Any, Optional import pytest from neuro_admin_client import AdminClient, Balance, Quota @@ -56,7 +48,7 @@ _EnforcePollingRunner = Callable[ - [JobPolicyEnforcer], AsyncContextManager[JobPolicyEnforcePoller] + [JobPolicyEnforcer], AbstractAsyncContextManager[JobPolicyEnforcePoller] ] @@ -154,7 +146,9 @@ class TestJobPolicyEnforcePoller: @pytest.fixture async def run_enforce_polling( self, job_policy_enforcer_config: JobPolicyEnforcerConfig - ) -> Callable[[JobPolicyEnforcer], AsyncContextManager[JobPolicyEnforcePoller]]: + ) -> Callable[ + [JobPolicyEnforcer], AbstractAsyncContextManager[JobPolicyEnforcePoller] + ]: @asynccontextmanager async def _factory( enforcer: JobPolicyEnforcer, @@ -252,10 +246,10 @@ def make_jobs( self, jobs_service: JobsService, job_request_factory: Callable[[], JobRequest], - ) -> Callable[[AuthUser, Optional[str], int], Awaitable[List[Job]]]: + ) -> Callable[[AuthUser, Optional[str], int], Awaitable[list[Job]]]: async def _make_jobs( user: AuthUser, org_name: Optional[str], count: int - ) -> List[Job]: + ) -> list[Job]: return [ ( await jobs_service.create_job( @@ -298,7 +292,7 @@ async def test_user_credits_disabled_do_nothing( self, has_credits_enforcer: CreditsLimitEnforcer, mock_auth_client: MockAuthClient, - make_jobs: Callable[[AuthUser, Optional[str], int], Awaitable[List[Job]]], + make_jobs: Callable[[AuthUser, Optional[str], int], Awaitable[list[Job]]], check_not_cancelled: Callable[[Iterable[Job]], Awaitable[None]], user_factory: UserFactory, test_cluster: str, @@ -316,7 +310,7 @@ async def test_user_has_credits_do_nothing( test_user: AuthUser, has_credits_enforcer: CreditsLimitEnforcer, mock_auth_client: MockAuthClient, - make_jobs: Callable[[AuthUser, Optional[str], int], Awaitable[List[Job]]], + make_jobs: Callable[[AuthUser, Optional[str], int], Awaitable[list[Job]]], check_not_cancelled: Callable[[Iterable[Job]], Awaitable[None]], user_factory: UserFactory, test_cluster: str, @@ -337,7 +331,7 @@ async def test_user_has_no_credits_kill_all( test_user: AuthUser, has_credits_enforcer: CreditsLimitEnforcer, mock_auth_client: MockAuthClient, - make_jobs: Callable[[AuthUser, Optional[str], int], Awaitable[List[Job]]], + make_jobs: Callable[[AuthUser, Optional[str], int], Awaitable[list[Job]]], check_cancelled: Callable[[Iterable[Job], str], Awaitable[None]], credits: Decimal, mock_admin_client: MockAdminClient, @@ -359,7 +353,7 @@ async def test_user_has_no_access_to_cluster_kill_all( test_user: AuthUser, has_credits_enforcer: CreditsLimitEnforcer, mock_auth_client: MockAuthClient, - make_jobs: Callable[[AuthUser, Optional[str], int], Awaitable[List[Job]]], + make_jobs: Callable[[AuthUser, Optional[str], int], Awaitable[list[Job]]], check_cancelled: Callable[[Iterable[Job], str], Awaitable[None]], mock_admin_client: MockAdminClient, ) -> None: @@ -375,7 +369,7 @@ async def test_orgs_credits_disabled_do_nothing( self, has_credits_enforcer: CreditsLimitEnforcer, mock_auth_client: 
MockAuthClient, - make_jobs: Callable[[AuthUser, Optional[str], int], Awaitable[List[Job]]], + make_jobs: Callable[[AuthUser, Optional[str], int], Awaitable[list[Job]]], check_not_cancelled: Callable[[Iterable[Job]], Awaitable[None]], org_factory: OrgFactory, user_factory: UserFactory, @@ -397,7 +391,7 @@ async def test_org_has_credits_do_nothing( test_user: AuthUser, has_credits_enforcer: CreditsLimitEnforcer, mock_auth_client: MockAuthClient, - make_jobs: Callable[[AuthUser, Optional[str], int], Awaitable[List[Job]]], + make_jobs: Callable[[AuthUser, Optional[str], int], Awaitable[list[Job]]], check_not_cancelled: Callable[[Iterable[Job]], Awaitable[None]], org_factory: OrgFactory, user_factory: UserFactory, @@ -424,7 +418,7 @@ async def test_org_has_no_credits_kill_all( test_user_with_org: AuthUser, has_credits_enforcer: CreditsLimitEnforcer, mock_auth_client: MockAuthClient, - make_jobs: Callable[[AuthUser, Optional[str], int], Awaitable[List[Job]]], + make_jobs: Callable[[AuthUser, Optional[str], int], Awaitable[list[Job]]], check_cancelled: Callable[[Iterable[Job], str], Awaitable[None]], credits: Decimal, mock_admin_client: MockAdminClient, @@ -447,7 +441,7 @@ async def test_org_has_no_access_to_cluster_kill_all( test_user_with_org: AuthUser, has_credits_enforcer: CreditsLimitEnforcer, mock_auth_client: MockAuthClient, - make_jobs: Callable[[AuthUser, Optional[str], int], Awaitable[List[Job]]], + make_jobs: Callable[[AuthUser, Optional[str], int], Awaitable[list[Job]]], check_cancelled: Callable[[Iterable[Job], str], Awaitable[None]], mock_admin_client: MockAdminClient, ) -> None: diff --git a/tests/unit/test_job_rest_validator.py b/tests/unit/test_job_rest_validator.py index d5f5fbc25..55e7bd6bd 100644 --- a/tests/unit/test_job_rest_validator.py +++ b/tests/unit/test_job_rest_validator.py @@ -1,5 +1,4 @@ from datetime import datetime, timezone -from typing import Tuple import pytest import trafaret as t @@ -149,7 +148,7 @@ def test_create_user_name_validator_none__ok(self) -> None: ("with123nums-and-dash", 1), ], ) - def test_user_name_validators__ok(self, pair: Tuple[str, int]) -> None: + def test_user_name_validators__ok(self, pair: tuple[str, int]) -> None: value = pair[0] * pair[1] validator = create_user_name_validator() assert validator.check(value) @@ -162,7 +161,7 @@ def test_user_name_validators__ok(self, pair: Tuple[str, int]) -> None: ("test/foo/bar", 1), ], ) - def test_role_name_validator__ok(self, pair: Tuple[str, int]) -> None: + def test_role_name_validator__ok(self, pair: tuple[str, int]) -> None: value = pair[0] * pair[1] validator = create_user_name_validator() assert validator.check(value) @@ -173,7 +172,7 @@ def test_role_name_validator__ok(self, pair: Tuple[str, int]) -> None: ("test/foo/bar", 1), ], ) - def test_base_owner_validator__fail(self, pair: Tuple[str, int]) -> None: + def test_base_owner_validator__fail(self, pair: tuple[str, int]) -> None: value = pair[0] * pair[1] validator = create_base_owner_name_validator() with pytest.raises(t.DataError): @@ -223,7 +222,7 @@ def test_base_owner_validator__fail(self, pair: Tuple[str, int]) -> None: ("46CAC3A6-2956-481B-B4AA-A80A6EAF2CDE", 1), # regression test ], ) - def test_user_name_validators__fail(self, pair: Tuple[str, int]) -> None: + def test_user_name_validators__fail(self, pair: tuple[str, int]) -> None: value = pair[0] * pair[1] validator = create_user_name_validator() with pytest.raises(t.DataError): diff --git a/tests/unit/test_job_service.py b/tests/unit/test_job_service.py index 
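The _EnforcePollingRunner hunks above swap typing.AsyncContextManager, which PEP 585 deprecates, for contextlib.AbstractAsyncContextManager; since 3.9 the ABC itself is subscriptable. A self-contained sketch with hypothetical names:

from collections.abc import AsyncIterator, Callable
from contextlib import AbstractAsyncContextManager, asynccontextmanager

class Poller:
    """Stand-in for JobPolicyEnforcePoller; illustrative only."""

@asynccontextmanager
async def run_polling() -> AsyncIterator[Poller]:
    poller = Poller()
    try:
        yield poller  # the context manager yields the running poller
    finally:
        pass  # shutdown would happen here

# Replaces Callable[..., AsyncContextManager[Poller]] from typing:
_Runner = Callable[[], AbstractAsyncContextManager[Poller]]
runner: _Runner = run_polling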
45b0a5df1..1c5c21747 100644 --- a/tests/unit/test_job_service.py +++ b/tests/unit/test_job_service.py @@ -1,10 +1,11 @@ import asyncio import base64 import json +from collections.abc import AsyncIterator, Callable from dataclasses import replace from datetime import datetime, timezone from decimal import Decimal -from typing import Any, AsyncIterator, Callable +from typing import Any from unittest import mock import pytest diff --git a/tests/unit/test_jobs_poller.py b/tests/unit/test_jobs_poller.py index d6659dc1c..4f00dffa0 100644 --- a/tests/unit/test_jobs_poller.py +++ b/tests/unit/test_jobs_poller.py @@ -1,5 +1,6 @@ import asyncio -from typing import Any, AsyncIterator, Callable +from collections.abc import AsyncIterator, Callable +from typing import Any import pytest from neuro_auth_client import User as AuthUser diff --git a/tests/unit/test_jobs_poller_client.py b/tests/unit/test_jobs_poller_client.py index 2c75c023e..28fa0fa06 100644 --- a/tests/unit/test_jobs_poller_client.py +++ b/tests/unit/test_jobs_poller_client.py @@ -1,7 +1,8 @@ +from collections.abc import AsyncIterator, Mapping from contextlib import asynccontextmanager from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Any, AsyncIterator, Mapping, Optional +from typing import Any, Optional import aiohttp.web import pytest diff --git a/tests/unit/test_kube_orchestrator.py b/tests/unit/test_kube_orchestrator.py index b74652627..104b768de 100644 --- a/tests/unit/test_kube_orchestrator.py +++ b/tests/unit/test_kube_orchestrator.py @@ -1,5 +1,5 @@ from pathlib import PurePath -from typing import Any, Dict, List +from typing import Any from unittest import mock import pytest @@ -637,10 +637,10 @@ class TestJobStatusItemFactory: def test_status( self, phase: str, - container_statuses: List[Dict[str, Any]], + container_statuses: list[dict[str, Any]], expected_status: JobStatus, ) -> None: - payload: Dict[str, Any] = {"phase": phase} + payload: dict[str, Any] = {"phase": phase} if container_statuses: payload["containerStatuses"] = container_statuses pod_status = PodStatus.from_primitive(payload) @@ -1035,7 +1035,7 @@ def test_to_primitive_with_labels(self) -> None: class TestService: @pytest.fixture - def service_payload(self) -> Dict[str, Any]: + def service_payload(self) -> dict[str, Any]: return { "metadata": {"name": "testservice"}, "spec": { @@ -1047,8 +1047,8 @@ def service_payload(self) -> Dict[str, Any]: @pytest.fixture def service_payload_with_uid( - self, service_payload: Dict[str, Any] - ) -> Dict[str, Any]: + self, service_payload: dict[str, Any] + ) -> dict[str, Any]: return { **service_payload, "metadata": { @@ -1057,7 +1057,7 @@ def service_payload_with_uid( }, } - def test_to_primitive(self, service_payload: Dict[str, Dict[str, Any]]) -> None: + def test_to_primitive(self, service_payload: dict[str, dict[str, Any]]) -> None: service = Service( name="testservice", selector=service_payload["spec"]["selector"], @@ -1066,7 +1066,7 @@ def test_to_primitive(self, service_payload: Dict[str, Dict[str, Any]]) -> None: assert service.to_primitive() == service_payload def test_to_primitive_with_labels( - self, service_payload: Dict[str, Dict[str, Any]] + self, service_payload: dict[str, dict[str, Any]] ) -> None: labels = {"label-name": "label-value"} expected_payload = service_payload.copy() @@ -1080,7 +1080,7 @@ def test_to_primitive_with_labels( assert service.to_primitive() == expected_payload def test_to_primitive_load_balancer( - self, service_payload: Dict[str, Dict[str, Any]] 
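The same split drives the import churn in test_jobs_poller.py and test_jobs_poller_client.py: since Python 3.9 the generic ABCs (Callable, Mapping, AsyncIterator, ...) are importable from collections.abc for annotations, leaving typing to supply only special forms such as Any and Optional. A sketch with a hypothetical function:

from collections.abc import AsyncIterator, Callable, Mapping
from typing import Any, Optional  # the names that still need typing

async def drain(
    source: Callable[[], AsyncIterator[Mapping[str, Any]]],
    limit: Optional[int] = None,
) -> list[Mapping[str, Any]]:
    rows: list[Mapping[str, Any]] = []
    async for row in source():
        rows.append(row)
        if limit is not None and len(rows) >= limit:
            break
    return rows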
+ self, service_payload: dict[str, dict[str, Any]] ) -> None: service = Service( name="testservice", @@ -1092,7 +1092,7 @@ def test_to_primitive_load_balancer( assert service.to_primitive() == service_payload def test_to_primitive_headless( - self, service_payload: Dict[str, Dict[str, Any]] + self, service_payload: dict[str, dict[str, Any]] ) -> None: service = Service( name="testservice", @@ -1104,7 +1104,7 @@ def test_to_primitive_headless( assert service.to_primitive() == service_payload def test_from_primitive( - self, service_payload_with_uid: Dict[str, Dict[str, Any]] + self, service_payload_with_uid: dict[str, dict[str, Any]] ) -> None: service = Service.from_primitive(service_payload_with_uid) assert service == Service( @@ -1115,7 +1115,7 @@ def test_from_primitive( ) def test_from_primitive_with_labels( - self, service_payload_with_uid: Dict[str, Dict[str, Any]] + self, service_payload_with_uid: dict[str, dict[str, Any]] ) -> None: labels = {"label-name": "label-value"} input_payload = service_payload_with_uid.copy() @@ -1130,7 +1130,7 @@ def test_from_primitive_with_labels( ) def test_from_primitive_node_port( - self, service_payload_with_uid: Dict[str, Dict[str, Any]] + self, service_payload_with_uid: dict[str, dict[str, Any]] ) -> None: service_payload_with_uid["spec"]["type"] = "NodePort" service = Service.from_primitive(service_payload_with_uid) @@ -1143,7 +1143,7 @@ def test_from_primitive_node_port( ) def test_from_primitive_headless( - self, service_payload_with_uid: Dict[str, Dict[str, Any]] + self, service_payload_with_uid: dict[str, dict[str, Any]] ) -> None: service_payload_with_uid["spec"]["clusterIP"] = "None" service = Service.from_primitive(service_payload_with_uid) @@ -1168,7 +1168,7 @@ def test_create_headless_for_pod(self) -> None: class TestContainerStatus: def test_no_state(self) -> None: - payload: Dict[str, Any] = {"state": {}} + payload: dict[str, Any] = {"state": {}} status = ContainerStatus(payload) assert status.is_waiting assert status.reason is None diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py index 9a0e28c59..8b0958987 100644 --- a/tests/unit/test_models.py +++ b/tests/unit/test_models.py @@ -1,7 +1,8 @@ +from collections.abc import Sequence from datetime import datetime, timedelta from decimal import Decimal from pathlib import PurePath -from typing import Any, Dict, Sequence +from typing import Any from unittest import mock import pytest @@ -58,7 +59,7 @@ class TestContainerRequestValidator: @pytest.fixture - def payload(self) -> Dict[str, Any]: + def payload(self) -> dict[str, Any]: return { "image": "testimage", "resources": {"cpu": 0.1, "memory_mb": 16}, @@ -71,7 +72,7 @@ def payload(self) -> Dict[str, Any]: } @pytest.fixture - def payload_with_zero_gpu(self) -> Dict[str, Any]: + def payload_with_zero_gpu(self) -> dict[str, Any]: return { "image": "testimage", "resources": {"cpu": 0.1, "memory_mb": 16, "gpu": 0}, @@ -84,7 +85,7 @@ def payload_with_zero_gpu(self) -> Dict[str, Any]: } @pytest.fixture - def payload_with_negative_gpu(self) -> Dict[str, Any]: + def payload_with_negative_gpu(self) -> dict[str, Any]: return { "image": "testimage", "resources": {"cpu": 0.1, "memory_mb": 16, "gpu": -1}, @@ -97,7 +98,7 @@ def payload_with_negative_gpu(self) -> Dict[str, Any]: } @pytest.fixture - def payload_with_one_gpu(self) -> Dict[str, Any]: + def payload_with_one_gpu(self) -> dict[str, Any]: return { "image": "testimage", "resources": {"cpu": 0.1, "memory_mb": 16, "gpu": 1}, @@ -110,7 +111,7 @@ def payload_with_one_gpu(self) -> 
Dict[str, Any]: } @pytest.fixture - def payload_with_too_many_gpu(self) -> Dict[str, Any]: + def payload_with_too_many_gpu(self) -> dict[str, Any]: return { "image": "testimage", "resources": {"cpu": 0.1, "memory_mb": 16, "gpu": 130}, @@ -123,7 +124,7 @@ def payload_with_too_many_gpu(self) -> Dict[str, Any]: } @pytest.fixture - def payload_with_dev_shm(self) -> Dict[str, Any]: + def payload_with_dev_shm(self) -> dict[str, Any]: return { "image": "testimage", "resources": {"cpu": 0.1, "memory_mb": 16, "shm": True}, @@ -135,7 +136,7 @@ def payload_with_dev_shm(self) -> Dict[str, Any]: ], } - def test_allowed_volumes(self, payload: Dict[str, Any]) -> None: + def test_allowed_volumes(self, payload: dict[str, Any]) -> None: validator = create_container_request_validator( allow_volumes=True, cluster_name="test-cluster" ) @@ -144,7 +145,7 @@ def test_allowed_volumes(self, payload: Dict[str, Any]) -> None: assert "shm" not in result["resources"] def test_allowed_volumes_with_shm( - self, payload_with_dev_shm: Dict[str, Any] + self, payload_with_dev_shm: dict[str, Any] ) -> None: validator = create_container_request_validator( allow_volumes=True, cluster_name="test-cluster" @@ -153,19 +154,19 @@ def test_allowed_volumes_with_shm( assert result["volumes"][0]["read_only"] assert result["resources"]["shm"] - def test_disallowed_volumes(self, payload: Dict[str, Any]) -> None: + def test_disallowed_volumes(self, payload: dict[str, Any]) -> None: validator = create_container_request_validator(cluster_name="test-cluster") with pytest.raises(ValueError, match="volumes is not allowed key"): validator.check(payload) - def test_with_zero_gpu(self, payload_with_zero_gpu: Dict[str, Any]) -> None: + def test_with_zero_gpu(self, payload_with_zero_gpu: dict[str, Any]) -> None: validator = create_container_request_validator( allow_volumes=True, cluster_name="test-cluster" ) result = validator.check(payload_with_zero_gpu) assert result["resources"]["gpu"] == 0 - def test_with_one_gpu(self, payload_with_one_gpu: Dict[str, Any]) -> None: + def test_with_one_gpu(self, payload_with_one_gpu: dict[str, Any]) -> None: validator = create_container_request_validator( allow_volumes=True, cluster_name="test-cluster" ) @@ -173,14 +174,14 @@ def test_with_one_gpu(self, payload_with_one_gpu: Dict[str, Any]) -> None: assert result["resources"]["gpu"] assert result["resources"]["gpu"] == 1 - def test_with_too_many_gpu(self, payload_with_too_many_gpu: Dict[str, Any]) -> None: + def test_with_too_many_gpu(self, payload_with_too_many_gpu: dict[str, Any]) -> None: validator = create_container_request_validator( allow_volumes=True, cluster_name="test-cluster" ) with pytest.raises(ValueError, match="gpu"): validator.check(payload_with_too_many_gpu) - def test_with_negative_gpu(self, payload_with_negative_gpu: Dict[str, Any]) -> None: + def test_with_negative_gpu(self, payload_with_negative_gpu: dict[str, Any]) -> None: validator = create_container_request_validator( allow_volumes=True, cluster_name="test-cluster" ) @@ -289,7 +290,7 @@ def test_tpu(self) -> None: "software_version": "1.14", } - def test_with_entrypoint_and_cmd(self, payload: Dict[str, Any]) -> None: + def test_with_entrypoint_and_cmd(self, payload: dict[str, Any]) -> None: payload["entrypoint"] = "/script.sh" payload["command"] = "arg1 arg2 arg3" validator = create_container_request_validator( @@ -299,13 +300,13 @@ def test_with_entrypoint_and_cmd(self, payload: Dict[str, Any]) -> None: assert result["entrypoint"] == "/script.sh" assert result["command"] == "arg1 arg2 
arg3" - def test_invalid_entrypoint(self, payload: Dict[str, Any]) -> None: + def test_invalid_entrypoint(self, payload: dict[str, Any]) -> None: payload["entrypoint"] = '"' validator = create_container_request_validator(cluster_name="test-cluster") with pytest.raises(DataError, match="invalid command format"): validator.check(payload) - def test_invalid_command(self, payload: Dict[str, Any]) -> None: + def test_invalid_command(self, payload: dict[str, Any]) -> None: payload["command"] = '"' validator = create_container_request_validator(cluster_name="test-cluster") with pytest.raises(DataError, match="invalid command format"): @@ -461,7 +462,7 @@ def test_validator(self) -> None: } def test_validator_default_preset(self) -> None: - request: Dict[str, Any] = {"container": {}} + request: dict[str, Any] = {"container": {}} validator = create_job_preset_validator( [ Preset( @@ -1048,7 +1049,7 @@ def test_create_from_query_fail(self, query: Any) -> None: factory(MultiDict(query)) # type: ignore -def make_access_tree(perm_dict: Dict[str, str]) -> ClientSubTreeViewRoot: +def make_access_tree(perm_dict: dict[str, str]) -> ClientSubTreeViewRoot: tree = ClientSubTreeViewRoot( scheme="job", path="/",