From e3dda39ac32a602b5c648054a8f9add6464b4db8 Mon Sep 17 00:00:00 2001
From: qdelamea
Date: Fri, 5 Jan 2024 09:31:56 +0100
Subject: [PATCH] feat: Python API update taskhandler and worker

---
 .../python/src/armonik/worker/__init__.py     |   6 +
 .../python/src/armonik/worker/taskhandler.py  | 103 +++++++++++++++--
 packages/python/tests/conftest.py             |   7 +-
 packages/python/tests/test_taskhandler.py     | 108 ++++++++++++++++++
 packages/python/tests/test_worker.py          |  80 +++++++++++++
 5 files changed, 291 insertions(+), 13 deletions(-)
 create mode 100644 packages/python/tests/test_taskhandler.py
 create mode 100644 packages/python/tests/test_worker.py

diff --git a/packages/python/src/armonik/worker/__init__.py b/packages/python/src/armonik/worker/__init__.py
index 508d49ae5..78a61174c 100644
--- a/packages/python/src/armonik/worker/__init__.py
+++ b/packages/python/src/armonik/worker/__init__.py
@@ -1,3 +1,9 @@
 from .worker import ArmoniKWorker
 from .taskhandler import TaskHandler
 from .seqlogger import ClefLogger
+
+__all__ = [
+    'ArmoniKWorker',
+    'TaskHandler',
+    'ClefLogger',
+]
diff --git a/packages/python/src/armonik/worker/taskhandler.py b/packages/python/src/armonik/worker/taskhandler.py
index 49eeb8ff3..7d18ef7db 100644
--- a/packages/python/src/armonik/worker/taskhandler.py
+++ b/packages/python/src/armonik/worker/taskhandler.py
@@ -1,12 +1,14 @@
 from __future__ import annotations
 import os
+from deprecation import deprecated
 from typing import Optional, Dict, List, Tuple, Union, cast

-from ..common import TaskOptions, TaskDefinition, Task
-from ..protogen.common.agent_common_pb2 import CreateTaskRequest, CreateResultsMetaDataRequest, CreateResultsMetaDataResponse, NotifyResultDataRequest
+from ..common import TaskOptions, TaskDefinition, Task, Result
+from ..protogen.common.agent_common_pb2 import CreateTaskRequest, CreateResultsMetaDataRequest, CreateResultsMetaDataResponse, NotifyResultDataRequest, CreateResultsRequest, CreateResultsResponse, SubmitTasksRequest, SubmitTasksResponse
 from ..protogen.common.objects_pb2 import TaskRequest, DataChunk, InitTaskRequest, TaskRequestHeader, Configuration
 from ..protogen.worker.agent_service_pb2_grpc import AgentStub
 from ..protogen.common.worker_common_pb2 import ProcessRequest
+from ..common.helpers import batched


 class TaskHandler:
@@ -31,6 +33,7 @@ def __init__(self, request: ProcessRequest, agent_client: AgentStub):
             with open(os.path.join(self.data_folder, dd), "rb") as f:
                 self.data_dependencies[dd] = f.read()

+    @deprecated(deprecated_in="3.15.0", details="Use submit_tasks instead and create the payload using create_results_metadata and send_results")
     def create_tasks(self, tasks: List[TaskDefinition], task_options: Optional[TaskOptions] = None) -> Tuple[List[Task], List[str]]:
         """Create new tasks for ArmoniK
@@ -67,21 +70,99 @@ def create_tasks(self, tasks: List[TaskDefinition], task_options: Optional[TaskO
             raise Exception("Unknown value")
         return tasks_created, tasks_creation_failed

-    def send_result(self, key: str, data: Union[bytes, bytearray]) -> None:
-        """ Send task result
+    def submit_tasks(self, tasks: List[TaskDefinition], default_task_options: Optional[TaskOptions] = None, batch_size: Optional[int] = 100) -> None:
+        """Submit tasks to the agent.

         Args:
-            key: Result key
-            data: Result data
+            tasks: List of task definitions.
+            default_task_options: Default task options, used when a task does not have its own options set.
+            batch_size: Batch size for submission.
         """
+        for tasks_batch in batched(tasks, batch_size):
+            task_creations = []
+
+            for t in tasks_batch:
+                task_creation = SubmitTasksRequest.TaskCreation(
+                    expected_output_keys=t.expected_output_ids,
+                    payload_id=t.payload_id,
+                    data_dependencies=t.data_dependencies
+                )
+                if t.options:
+                    task_creation.task_options.CopyFrom(t.options.to_message())
+                task_creations.append(task_creation)
+
+            request = SubmitTasksRequest(
+                session_id=self.session_id,
+                communication_token=self.token,
+                task_creations=task_creations
+            )
+
+            if default_task_options:
+                request.task_options.CopyFrom(default_task_options.to_message())

-        with open(os.path.join(self.data_folder, key), "wb") as f:
-            f.write(data)
+            self._client.SubmitTasks(request)

-        self._client.NotifyResultData(NotifyResultDataRequest(ids=[NotifyResultDataRequest.ResultIdentifier(session_id=self.session_id, result_id=key)], communication_token=self.token))
+    def send_results(self, results_data: Dict[str, bytes | bytearray]) -> None:
+        """Send results.
+
+        Args:
+            results_data: A dictionary mapping each result ID to its data.
+        """
+        for result_id, result_data in results_data.items():
+            with open(os.path.join(self.data_folder, result_id), "wb") as f:
+                f.write(result_data)
+
+        request = NotifyResultDataRequest(
+            ids=[NotifyResultDataRequest.ResultIdentifier(session_id=self.session_id, result_id=result_id) for result_id in results_data.keys()],
+            communication_token=self.token
+        )
+        self._client.NotifyResultData(request)

-    def get_results_ids(self, names: List[str]) -> Dict[str, str]:
-        return {r.name: r.result_id for r in cast(CreateResultsMetaDataResponse, self._client.CreateResultsMetaData(CreateResultsMetaDataRequest(results=[CreateResultsMetaDataRequest.ResultCreate(name=n) for n in names], session_id=self.session_id, communication_token=self.token))).results}
+    def create_results_metadata(self, result_names: List[str], batch_size: int = 100) -> Dict[str, Result]:
+        """Create the metadata of multiple results at once.
+        Data has to be uploaded separately.
+
+        Args:
+            result_names: The names of the results to create.
+            batch_size: Batch size for querying.
+
+        Returns:
+            A dictionary mapping each result name to its result summary.
+        """
+        results = {}
+        for result_names_batch in batched(result_names, batch_size):
+            request = CreateResultsMetaDataRequest(
+                results=[CreateResultsMetaDataRequest.ResultCreate(name=result_name) for result_name in result_names_batch],
+                session_id=self.session_id,
+                communication_token=self.token
+            )
+            response: CreateResultsMetaDataResponse = self._client.CreateResultsMetaData(request)
+            for result_message in response.results:
+                results[result_message.name] = Result.from_message(result_message)
+        return results
+
+    def create_results(self, results_data: Dict[str, bytes], batch_size: int = 1) -> Dict[str, Result]:
+        """Create results with their data included in the request.
+
+        Args:
+            results_data: A dictionary mapping the result names to their actual data.
+            batch_size: Batch size for querying.
+
+        Returns:
+            A dictionary mapping each result name to its corresponding result summary.
+ """ + results = {} + for results_ids_batch in batched(results_data.keys(), batch_size): + request = CreateResultsRequest( + results=[CreateResultsRequest.ResultCreate(name=name, data=results_data[name]) for name in results_ids_batch], + session_id=self.session_id, + communication_token=self.token + ) + response: CreateResultsResponse = self._client.CreateResults(request) + for message in response.results: + results[message.name] = Result.from_message(message) + return results def _to_request_stream_internal(request, communication_token, is_last, chunk_max_size): req = CreateTaskRequest( diff --git a/packages/python/tests/conftest.py b/packages/python/tests/conftest.py index 0e3a46067..c5eaada26 100644 --- a/packages/python/tests/conftest.py +++ b/packages/python/tests/conftest.py @@ -44,6 +44,7 @@ def clean_up(request): # Remove the temporary files created for testing os.remove(os.path.join(data_folder, "payload-id")) os.remove(os.path.join(data_folder, "dd-id")) + os.remove(os.path.join(data_folder, "result-id")) # Reset the mock server counters try: @@ -54,7 +55,7 @@ def clean_up(request): print("An error occurred when resetting the server: " + str(e)) -def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[ArmoniKPartitions, ArmoniKResults, ArmoniKSessions, ArmoniKTasks, ArmoniKVersions]: +def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[AgentStub, ArmoniKPartitions, ArmoniKResults, ArmoniKSessions, ArmoniKTasks, ArmoniKVersions]: """ Get the ArmoniK client instance based on the specified service name. @@ -63,7 +64,7 @@ def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[ArmoniK endpoint (str, optional): The gRPC server endpoint. Defaults to grpc_endpoint. Returns: - Union[ArmoniKPartitions, ArmoniKResults, ArmoniKSessions, ArmoniKTasks, ArmoniKVersions] + Union[AgentStub, ArmoniKPartitions, ArmoniKResults, ArmoniKSessions, ArmoniKTasks, ArmoniKVersions] An instance of the specified ArmoniK client. 
diff --git a/packages/python/tests/conftest.py b/packages/python/tests/conftest.py
index 0e3a46067..c5eaada26 100644
--- a/packages/python/tests/conftest.py
+++ b/packages/python/tests/conftest.py
@@ -44,6 +44,7 @@ def clean_up(request):
     # Remove the temporary files created for testing
     os.remove(os.path.join(data_folder, "payload-id"))
     os.remove(os.path.join(data_folder, "dd-id"))
+    os.remove(os.path.join(data_folder, "result-id"))

     # Reset the mock server counters
     try:
@@ -54,7 +55,7 @@
         print("An error occurred when resetting the server: " + str(e))


-def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[ArmoniKPartitions, ArmoniKResults, ArmoniKSessions, ArmoniKTasks, ArmoniKVersions]:
+def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[AgentStub, ArmoniKPartitions, ArmoniKResults, ArmoniKSessions, ArmoniKTasks, ArmoniKVersions]:
     """
     Get the ArmoniK client instance based on the specified service name.

@@ -63,7 +64,7 @@ def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[ArmoniK
         endpoint (str, optional): The gRPC server endpoint. Defaults to grpc_endpoint.

     Returns:
-        Union[ArmoniKPartitions, ArmoniKResults, ArmoniKSessions, ArmoniKTasks, ArmoniKVersions]
+        Union[AgentStub, ArmoniKPartitions, ArmoniKResults, ArmoniKSessions, ArmoniKTasks, ArmoniKVersions]
             An instance of the specified ArmoniK client.

     Raises:
@@ -75,6 +76,8 @@ def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[ArmoniK
     """
     channel = grpc.insecure_channel(endpoint).__enter__()
     match client_name:
+        case "Agent":
+            return AgentStub(channel)
         case "Partitions":
             return ArmoniKPartitions(channel)
         case "Results":
diff --git a/packages/python/tests/test_taskhandler.py b/packages/python/tests/test_taskhandler.py
new file mode 100644
index 000000000..f5e8634eb
--- /dev/null
+++ b/packages/python/tests/test_taskhandler.py
@@ -0,0 +1,108 @@
+import datetime
+import logging
+import warnings
+
+from .conftest import all_rpc_called, rpc_called, get_client, data_folder
+from armonik.common import TaskDefinition, TaskOptions
+from armonik.worker import TaskHandler
+from armonik.protogen.worker.agent_service_pb2_grpc import AgentStub
+from armonik.protogen.common.worker_common_pb2 import ProcessRequest
+from armonik.protogen.common.objects_pb2 import Configuration
+
+
+logging.basicConfig()
+logging.getLogger().setLevel(logging.INFO)
+
+
+class TestTaskHandler:
+
+    request = ProcessRequest(
+        communication_token="token",
+        session_id="session-id",
+        task_id="task-id",
+        expected_output_keys=["result-id"],
+        payload_id="payload-id",
+        data_dependencies=["dd-id"],
+        data_folder=data_folder,
+        configuration=Configuration(data_chunk_max_size=8000),
+        task_options=TaskOptions(
+            max_duration=datetime.timedelta(seconds=1),
+            priority=1,
+            max_retries=1
+        ).to_message()
+    )
+
+    def test_taskhandler_init(self):
+        task_handler = TaskHandler(self.request, get_client("Agent"))
+
+        assert task_handler.session_id == "session-id"
+        assert task_handler.task_id == "task-id"
+        assert task_handler.task_options == TaskOptions(
+            max_duration=datetime.timedelta(seconds=1),
+            priority=1,
+            max_retries=1,
+            partition_id='',
+            application_name='',
+            application_version='',
+            application_namespace='',
+            application_service='',
+            engine_type='',
+            options={}
+        )
+        assert task_handler.token == "token"
+        assert task_handler.expected_results == ["result-id"]
+        assert task_handler.configuration == Configuration(data_chunk_max_size=8000)
+        assert task_handler.payload_id == "payload-id"
+        assert task_handler.data_folder == data_folder
+        assert task_handler.payload == "payload".encode()
+        assert task_handler.data_dependencies == {"dd-id": "dd".encode()}
+
+    def test_create_task(self):
+        with warnings.catch_warnings(record=True) as w:
+            # Cause all warnings to always be triggered.
+ warnings.simplefilter("always") + + task_handler = TaskHandler(self.request, get_client("Agent")) + tasks, errors = task_handler.create_tasks([TaskDefinition( + payload=b"payload", + expected_output_ids=["result-id"], + data_dependencies=[])]) + + assert issubclass(w[-1].category, DeprecationWarning) + assert rpc_called("Agent", "CreateTask") + assert tasks == [] + assert errors == [] + + def test_submit_tasks(self): + task_handler = TaskHandler(self.request, get_client("Agent")) + tasks = task_handler.submit_tasks([TaskDefinition(payload_id="payload-id", + expected_output_ids=["result-id"], + data_dependencies=[])] + ) + + assert rpc_called("Agent", "SubmitTasks") + assert tasks is None + + def test_send_results(self): + task_handler = TaskHandler(self.request, get_client("Agent")) + resuls = task_handler.send_results({"result-id": b"result data"}) + assert rpc_called("Agent", "NotifyResultData") + assert resuls is None + + def test_create_result_metadata(self): + task_handler = TaskHandler(self.request, get_client("Agent")) + results = task_handler.create_results_metadata(["result-name"]) + + assert rpc_called("Agent", "CreateResultsMetaData") + # TODO: Mock must be updated to return something and so that changes the following assertions + assert results == {} + + def test_create_results(self): + task_handler = TaskHandler(self.request, get_client("Agent")) + results = task_handler.create_results({"result-name": b"test data"}) + + assert rpc_called("Agent", "CreateResults") + assert results == {} + + def test_service_fully_implemented(self): + assert all_rpc_called("Agent", missings=["GetCommonData", "GetDirectData", "GetResourceData"]) diff --git a/packages/python/tests/test_worker.py b/packages/python/tests/test_worker.py new file mode 100644 index 000000000..42198bcdf --- /dev/null +++ b/packages/python/tests/test_worker.py @@ -0,0 +1,80 @@ +import datetime +import grpc +import logging +import os +import pytest + +from .conftest import data_folder, grpc_endpoint +from armonik.worker import ArmoniKWorker, TaskHandler, ClefLogger +from armonik.common import Output, TaskOptions +from armonik.protogen.common.objects_pb2 import Empty, Configuration +from armonik.protogen.common.worker_common_pb2 import ProcessRequest + + +def do_nothing(_: TaskHandler) -> Output: + return Output() + + +def throw_error(_: TaskHandler) -> Output: + raise ValueError("TestError") + + +def return_error(_: TaskHandler) -> Output: + return Output("TestError") + + +def return_and_send(th: TaskHandler) -> Output: + th.send_results({th.expected_results[0]: b"result"}) + return Output() + + +class TestWorker: + + request = ProcessRequest( + communication_token="token", + session_id="session-id", + task_id="task-id", + expected_output_keys=["result-id"], + payload_id="payload-id", + data_dependencies=["dd-id"], + data_folder=data_folder, + configuration=Configuration(data_chunk_max_size=8000), + task_options=TaskOptions( + max_duration=datetime.timedelta(seconds=1), + priority=1, + max_retries=1 + ).to_message() + ) + + def test_do_nothing(self): + with grpc.insecure_channel(grpc_endpoint) as agent_channel: + worker = ArmoniKWorker(agent_channel, do_nothing, logger=ClefLogger("TestLogger", level=logging.CRITICAL)) + reply = worker.Process(self.request, None) + assert Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" else None).success + worker.HealthCheck(Empty(), None) + + def test_should_return_none(self): + with grpc.insecure_channel(grpc_endpoint) as agent_channel: + worker = 
+            reply = worker.Process(self.request, None)
+            assert reply is None
+
+    def test_should_error(self):
+        with grpc.insecure_channel(grpc_endpoint) as agent_channel:
+            worker = ArmoniKWorker(agent_channel, return_error, logger=ClefLogger("TestLogger", level=logging.CRITICAL))
+            reply = worker.Process(self.request, None)
+            output = Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" else None)
+            assert not output.success
+            assert output.error == "TestError"
+
+    def test_should_write_result(self):
+        with grpc.insecure_channel(grpc_endpoint) as agent_channel:
+            worker = ArmoniKWorker(agent_channel, return_and_send, logger=ClefLogger("TestLogger", level=logging.DEBUG))
+            reply = worker.Process(self.request, None)
+            assert reply is not None
+            output = Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" else None)
+            assert output.success
+            assert os.path.exists(os.path.join(data_folder, self.request.expected_output_keys[0]))
+            with open(os.path.join(data_folder, self.request.expected_output_keys[0]), "rb") as f:
+                value = f.read()
+            assert len(value) > 0
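To round out the tests above, a minimal sketch of how such a worker is typically wired up as a standalone process. This is an illustration under assumptions, not part of the patch: the socket addresses are placeholders, and the `start` serving entry point is assumed from the ArmoniKWorker API rather than shown in this diff.

    import logging

    import grpc

    from armonik.common import Output
    from armonik.worker import ArmoniKWorker, ClefLogger, TaskHandler


    def processor(handler: TaskHandler) -> Output:
        try:
            # Echo the payload back as the task's single expected result.
            handler.send_results({handler.expected_results[0]: handler.payload})
            return Output()
        except Exception as e:
            # A non-empty Output marks the task as failed with an error message.
            return Output(str(e))


    if __name__ == "__main__":
        # Placeholder socket addresses; real deployments take these from the
        # ArmoniK compute-plane configuration.
        with grpc.insecure_channel("unix:///cache/armonik_agent.sock") as agent_channel:
            worker = ArmoniKWorker(agent_channel, processor, logger=ClefLogger("Worker", level=logging.INFO))
            worker.start("unix:///cache/armonik_worker.sock")  # assumed entry point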