From 888fa21302ede50b2eb2c26b55ca0311b879daf8 Mon Sep 17 00:00:00 2001 From: qdelamea Date: Thu, 4 Jan 2024 13:07:21 +0100 Subject: [PATCH] feat: Python API update sessions and partitions services --- .../python/src/armonik/client/__init__.py | 2 + .../python/src/armonik/client/partitions.py | 65 +++++++++++++++++ .../python/src/armonik/client/sessions.py | 33 ++++++--- .../python/src/armonik/common/__init__.py | 4 +- .../python/src/armonik/common/enumwrapper.py | 22 ++++++ packages/python/src/armonik/common/objects.py | 22 ++++++ packages/python/tests/conftest.py | 10 ++- packages/python/tests/test_partitions.py | 41 +++++++++++ packages/python/tests/test_sessions.py | 72 +++++++++++++++++++ 9 files changed, 255 insertions(+), 16 deletions(-) create mode 100644 packages/python/src/armonik/client/partitions.py create mode 100644 packages/python/tests/test_partitions.py create mode 100644 packages/python/tests/test_sessions.py diff --git a/packages/python/src/armonik/client/__init__.py b/packages/python/src/armonik/client/__init__.py index a4a7cd74e..398b36ca2 100644 --- a/packages/python/src/armonik/client/__init__.py +++ b/packages/python/src/armonik/client/__init__.py @@ -1,3 +1,5 @@ +from .partitions import ArmoniKPartitions, PartitionFieldFilter +from .sessions import ArmoniKSessions, SessionFieldFilter from .submitter import ArmoniKSubmitter from .tasks import ArmoniKTasks, TaskFieldFilter from .results import ArmoniKResults, ResultFieldFilter diff --git a/packages/python/src/armonik/client/partitions.py b/packages/python/src/armonik/client/partitions.py new file mode 100644 index 000000000..9c72495cc --- /dev/null +++ b/packages/python/src/armonik/client/partitions.py @@ -0,0 +1,65 @@ +from typing import cast, List, Tuple + +from grpc import Channel + +from ..common import Direction, Partition +from ..common.filter import Filter, NumberFilter +from ..protogen.client.partitions_service_pb2_grpc import PartitionsStub +from ..protogen.common.partitions_common_pb2 import ListPartitionsRequest, ListPartitionsResponse, GetPartitionRequest, GetPartitionResponse +from ..protogen.common.partitions_fields_pb2 import PartitionField, PartitionRawField, PARTITION_RAW_ENUM_FIELD_PRIORITY +from ..protogen.common.partitions_filters_pb2 import Filters as rawFilters, FiltersAnd as rawFiltersAnd, FilterField as rawFilterField +from ..protogen.common.sort_direction_pb2 import SortDirection + + +class PartitionFieldFilter: + PRIORITY = NumberFilter( + PartitionField(partition_raw_field=PartitionRawField(field=PARTITION_RAW_ENUM_FIELD_PRIORITY)), + rawFilters, + rawFiltersAnd, + rawFilterField + ) + + +class ArmoniKPartitions: + def __init__(self, grpc_channel: Channel): + """ Result service client + + Args: + grpc_channel: gRPC channel to use + """ + self._client = PartitionsStub(grpc_channel) + + def list_partitions(self, partition_filter: Filter | None = None, page: int = 0, page_size: int = 1000, sort_field: Filter = PartitionFieldFilter.PRIORITY, sort_direction: SortDirection = Direction.ASC) -> Tuple[int, List[Partition]]: + """List partitions based on a filter. + + Args: + partition_filter: Filter to apply when listing partitions + page: page number to request, useful for pagination, defaults to 0 + page_size: size of a page, defaults to 1000 + sort_field: field to sort the resulting list by, defaults to the status + sort_direction: direction of the sort, defaults to ascending + + Returns: + A tuple containing : + - The total number of results for the given filter + - The obtained list of results + """ + request = ListPartitionsRequest( + page=page, + page_size=page_size, + filters=cast(rawFilters, partition_filter.to_disjunction().to_message()) if partition_filter else None, + sort=ListPartitionsRequest.Sort(field=cast(PartitionField, sort_field.field), direction=sort_direction), + ) + response: ListPartitionsResponse = self._client.ListPartitions(request) + return response.total, [Partition.from_message(p) for p in response.partitions] + + def get_partition(self, partition_id: str) -> Partition: + """Get a partition by its ID. + + Args: + partition_id: The partition ID. + + Return: + The partition summary. + """ + return Partition.from_message(self._client.GetPartition(GetPartitionRequest(id=partition_id)).partition) diff --git a/packages/python/src/armonik/client/sessions.py b/packages/python/src/armonik/client/sessions.py index 8f676144d..84c96abfd 100644 --- a/packages/python/src/armonik/client/sessions.py +++ b/packages/python/src/armonik/client/sessions.py @@ -53,14 +53,26 @@ def create_session(self, default_task_options: TaskOptions, partition_ids: Optio Returns: Session Id """ - if partition_ids is None: - partition_ids = [] - request = CreateSessionRequest(default_task_option=default_task_options.to_message()) - for partition in partition_ids: - request.partition_ids.append(partition) + request = CreateSessionRequest( + default_task_option=default_task_options.to_message(), + partition_ids=partition_ids if partition_ids else [] + ) return self._client.CreateSession(request).session_id - def list_sessions(self, task_filter: Filter, page: int = 0, page_size: int = 1000, sort_field: Filter = SessionFieldFilter.STATUS, sort_direction: SortDirection = Direction.ASC) -> Tuple[int, List[Session]]: + def get_session(self, session_id: str): + """Get a session by its ID. + + Args: + session_id: The ID of the session. + + Return: + The session summary. + """ + request = GetSessionRequest(session_id=session_id) + response: GetSessionResponse = self._client.GetSession(request) + return Session.from_message(response.session) + + def list_sessions(self, session_filter: Filter | None = None, page: int = 0, page_size: int = 1000, sort_field: Filter = SessionFieldFilter.STATUS, sort_direction: SortDirection = Direction.ASC) -> Tuple[int, List[Session]]: """ List sessions @@ -76,14 +88,14 @@ def list_sessions(self, task_filter: Filter, page: int = 0, page_size: int = 100 - The total number of sessions for the given filter - The obtained list of sessions """ - request : ListSessionsRequest = ListSessionsRequest( + request = ListSessionsRequest( page=page, page_size=page_size, - filters=cast(rawFilters, task_filter.to_disjunction().to_message()), + filters=cast(rawFilters, session_filter.to_disjunction().to_message()) if session_filter else None, sort=ListSessionsRequest.Sort(field=cast(SessionField, sort_field.field), direction=sort_direction), ) - list_response : ListSessionsResponse = self._client.ListSessions(request) - return list_response.total, [Session.from_message(t) for t in list_response.sessions] + response : ListSessionsResponse = self._client.ListSessions(request) + return response.total, [Session.from_message(s) for s in response.sessions] def cancel_session(self, session_id: str) -> None: """Cancel a session @@ -92,4 +104,3 @@ def cancel_session(self, session_id: str) -> None: session_id: Id of the session to b cancelled """ self._client.CancelSession(CancelSessionRequest(session_id=session_id)) - \ No newline at end of file diff --git a/packages/python/src/armonik/common/__init__.py b/packages/python/src/armonik/common/__init__.py index 5d44f4c9f..5901a7262 100644 --- a/packages/python/src/armonik/common/__init__.py +++ b/packages/python/src/armonik/common/__init__.py @@ -1,4 +1,4 @@ from .helpers import datetime_to_timestamp, timestamp_to_datetime, duration_to_timedelta, timedelta_to_duration, get_task_filter -from .objects import Task, TaskDefinition, TaskOptions, Output, ResultAvailability, Session, Result -from .enumwrapper import HealthCheckStatus, TaskStatus, Direction, ResultStatus +from .objects import Task, TaskDefinition, TaskOptions, Output, ResultAvailability, Session, Result, Partition +from .enumwrapper import HealthCheckStatus, TaskStatus, Direction, ResultStatus, SessionStatus from .filter import StringFilter, StatusFilter diff --git a/packages/python/src/armonik/common/enumwrapper.py b/packages/python/src/armonik/common/enumwrapper.py index 9c19a9a82..d6fe134cb 100644 --- a/packages/python/src/armonik/common/enumwrapper.py +++ b/packages/python/src/armonik/common/enumwrapper.py @@ -1,8 +1,10 @@ from __future__ import annotations from ..protogen.common.task_status_pb2 import TaskStatus as RawStatus, _TASKSTATUS, TASK_STATUS_CANCELLED, TASK_STATUS_CANCELLING, TASK_STATUS_COMPLETED, TASK_STATUS_CREATING, TASK_STATUS_DISPATCHED, TASK_STATUS_ERROR, TASK_STATUS_PROCESSED, TASK_STATUS_PROCESSING, TASK_STATUS_SUBMITTED, TASK_STATUS_TIMEOUT, TASK_STATUS_UNSPECIFIED, TASK_STATUS_RETRIED +from ..protogen.common.events_common_pb2 import EventsEnum as rawEventsEnum, EVENTS_ENUM_UNSPECIFIED, EVENTS_ENUM_NEW_TASK, EVENTS_ENUM_TASK_STATUS_UPDATE, EVENTS_ENUM_NEW_RESULT, EVENTS_ENUM_RESULT_STATUS_UPDATE, EVENTS_ENUM_RESULT_OWNER_UPDATE from ..protogen.common.session_status_pb2 import SessionStatus as RawSessionStatus, _SESSIONSTATUS, SESSION_STATUS_UNSPECIFIED, SESSION_STATUS_CANCELLED, SESSION_STATUS_RUNNING from ..protogen.common.result_status_pb2 import ResultStatus as RawResultStatus, _RESULTSTATUS, RESULT_STATUS_UNSPECIFIED, RESULT_STATUS_CREATED, RESULT_STATUS_COMPLETED, RESULT_STATUS_ABORTED, RESULT_STATUS_NOTFOUND +from ..protogen.common.health_checks_common_pb2 import HEALTH_STATUS_ENUM_UNSPECIFIED, HEALTH_STATUS_ENUM_HEALTHY, HEALTH_STATUS_ENUM_DEGRADED, HEALTH_STATUS_ENUM_UNHEALTHY from ..protogen.common.worker_common_pb2 import HealthCheckReply from ..protogen.common.sort_direction_pb2 import SORT_DIRECTION_ASC, SORT_DIRECTION_DESC @@ -58,3 +60,23 @@ def name_from_value(status: RawResultStatus) -> str: COMPLETED = RESULT_STATUS_COMPLETED ABORTED = RESULT_STATUS_ABORTED NOTFOUND = RESULT_STATUS_NOTFOUND + + +class EventTypes: + UNSPECIFIED = EVENTS_ENUM_UNSPECIFIED + NEW_TASK = EVENTS_ENUM_NEW_TASK + TASK_STATUS_UPDATE = EVENTS_ENUM_TASK_STATUS_UPDATE + NEW_RESULT = EVENTS_ENUM_NEW_RESULT + RESULT_STATUS_UPDATE = EVENTS_ENUM_RESULT_STATUS_UPDATE + RESULT_OWNER_UPDATE = EVENTS_ENUM_RESULT_OWNER_UPDATE + + @classmethod + def from_string(cls, name: str): + return getattr(cls, name.upper()) + + +class ServiceHealthCheckStatus: + UNSPECIFIED = HEALTH_STATUS_ENUM_UNSPECIFIED + HEALTHY = HEALTH_STATUS_ENUM_HEALTHY + DEGRADED = HEALTH_STATUS_ENUM_DEGRADED + UNHEALTHY = HEALTH_STATUS_ENUM_UNHEALTHY diff --git a/packages/python/src/armonik/common/objects.py b/packages/python/src/armonik/common/objects.py index 340821b8f..1b5801f7a 100644 --- a/packages/python/src/armonik/common/objects.py +++ b/packages/python/src/armonik/common/objects.py @@ -215,3 +215,25 @@ def from_message(cls, result_raw: ResultRaw) -> "Result": result_id=result_raw.result_id, size=result_raw.size ) + +@dataclass +class Partition: + id: str + parent_partition_ids: List[str] + pod_reserved: int + pod_max: int + pod_configuration: Dict[str, str] + preemption_percentage: int + priority: int + + @classmethod + def from_message(cls, partition_raw: PartitionRaw) -> "Partition": + return cls( + id=partition_raw.id, + parent_partition_ids=partition_raw.parent_partition_ids, + pod_reserved=partition_raw.pod_reserved, + pod_max=partition_raw.pod_max, + pod_configuration=partition_raw.pod_configuration, + preemption_percentage=partition_raw.preemption_percentage, + priority=partition_raw.priority + ) diff --git a/packages/python/tests/conftest.py b/packages/python/tests/conftest.py index b60a02483..0e3a46067 100644 --- a/packages/python/tests/conftest.py +++ b/packages/python/tests/conftest.py @@ -3,7 +3,7 @@ import pytest import requests -from armonik.client import ArmoniKResults, ArmoniKTasks, ArmoniKVersions +from armonik.client import ArmoniKPartitions, ArmoniKResults, ArmoniKSessions, ArmoniKTasks, ArmoniKVersions from armonik.protogen.worker.agent_service_pb2_grpc import AgentStub from typing import List, Union @@ -54,7 +54,7 @@ def clean_up(request): print("An error occurred when resetting the server: " + str(e)) -def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[ArmoniKTasks, ArmoniKVersions]: +def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[ArmoniKPartitions, ArmoniKResults, ArmoniKSessions, ArmoniKTasks, ArmoniKVersions]: """ Get the ArmoniK client instance based on the specified service name. @@ -63,7 +63,7 @@ def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[ArmoniK endpoint (str, optional): The gRPC server endpoint. Defaults to grpc_endpoint. Returns: - Union[ArmoniKTasks, ArmoniKVersions] + Union[ArmoniKPartitions, ArmoniKResults, ArmoniKSessions, ArmoniKTasks, ArmoniKVersions] An instance of the specified ArmoniK client. Raises: @@ -75,8 +75,12 @@ def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[ArmoniK """ channel = grpc.insecure_channel(endpoint).__enter__() match client_name: + case "Partitions": + return ArmoniKPartitions(channel) case "Results": return ArmoniKResults(channel) + case "Sessions": + return ArmoniKSessions(channel) case "Tasks": return ArmoniKTasks(channel) case "Versions": diff --git a/packages/python/tests/test_partitions.py b/packages/python/tests/test_partitions.py new file mode 100644 index 000000000..d46bd3063 --- /dev/null +++ b/packages/python/tests/test_partitions.py @@ -0,0 +1,41 @@ +from .conftest import all_rpc_called, rpc_called, get_client +from armonik.client import ArmoniKPartitions, PartitionFieldFilter +from armonik.common import Partition + + +class TestArmoniKPartitions: + + def test_get_partitions(self): + partitions_client: ArmoniKPartitions = get_client("Partitions") + partition = partitions_client.get_partition("partition-id") + + assert rpc_called("Partitions", "GetPartition") + assert isinstance(partition, Partition) + assert partition.id == 'partition-id' + assert partition.parent_partition_ids == [] + assert partition.pod_reserved == 1 + assert partition.pod_max == 1 + assert partition.pod_configuration == {} + assert partition.preemption_percentage == 0 + assert partition.priority == 1 + + def test_list_partitions_no_filter(self): + partitions_client: ArmoniKPartitions = get_client("Partitions") + num, partitions = partitions_client.list_partitions() + + assert rpc_called("Partitions", "GetPartition") + # TODO: Mock must be updated to return something and so that changes the following assertions + assert num == 0 + assert partitions == [] + + def test_list_partitions_with_filter(self): + partitions_client: ArmoniKPartitions = get_client("Partitions") + num, partitions = partitions_client.list_partitions(PartitionFieldFilter.PRIORITY == 1) + + assert rpc_called("Partitions", "GetPartition", 2) + # TODO: Mock must be updated to return something and so that changes the following assertions + assert num == 0 + assert partitions == [] + + def test_service_fully_implemented(self): + assert all_rpc_called("Partitions") diff --git a/packages/python/tests/test_sessions.py b/packages/python/tests/test_sessions.py new file mode 100644 index 000000000..c4c0173a4 --- /dev/null +++ b/packages/python/tests/test_sessions.py @@ -0,0 +1,72 @@ +import datetime + +from .conftest import all_rpc_called, rpc_called, get_client +from armonik.client import ArmoniKSessions, SessionFieldFilter +from armonik.common import Session, SessionStatus, TaskOptions + + +class TestArmoniKSessions: + + def test_create_session(self): + sessions_client: ArmoniKSessions = get_client("Sessions") + default_task_options = TaskOptions( + max_duration=datetime.timedelta(seconds=1), + priority=1, + max_retries=1 + ) + session_id = sessions_client.create_session(default_task_options) + + assert rpc_called("Sessions", "CreateSession") + assert session_id == "session-id" + + def test_get_session(self): + sessions_client: ArmoniKSessions = get_client("Sessions") + session = sessions_client.get_session("session-id") + + assert rpc_called("Sessions", "GetSession") + assert isinstance(session, Session) + assert session.session_id == 'session-id' + assert session.status == SessionStatus.CANCELLED + assert session.partition_ids == [] + assert session.options == TaskOptions( + max_duration=datetime.timedelta(0), + priority=0, + max_retries=0, + partition_id='', + application_name='', + application_version='', + application_namespace='', + application_service='', + engine_type='', + options={} + ) + assert session.created_at == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) + assert session.cancelled_at == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) + assert session.duration == datetime.timedelta(0) + + def test_list_session_no_filter(self): + sessions_client: ArmoniKSessions = get_client("Sessions") + num, sessions = sessions_client.list_sessions() + + assert rpc_called("Sessions", "ListSessions") + # TODO: Mock must be updated to return something and so that changes the following assertions + assert num == 0 + assert sessions == [] + + def test_list_session_with_filter(self): + sessions_client: ArmoniKSessions = get_client("Sessions") + num, sessions = sessions_client.list_sessions(SessionFieldFilter.STATUS == SessionStatus.RUNNING) + + assert rpc_called("Sessions", "ListSessions", 2) + # TODO: Mock must be updated to return something and so that changes the following assertions + assert num == 0 + assert sessions == [] + + def test_cancel_session(self): + sessions_client: ArmoniKSessions = get_client("Sessions") + sessions_client.cancel_session("session-id") + + assert rpc_called("Sessions", "CancelSession") + + def test_service_fully_implemented(self): + assert all_rpc_called("Sessions")