From 9cfc4b27aa518d4e7732790a1a4b28b4d74eaca8 Mon Sep 17 00:00:00 2001
From: Dylan Brasseur
Date: Fri, 23 Aug 2024 09:59:34 +0200
Subject: [PATCH] feat: Removing deprecated task handler function

fix: Fixed result parsing in task handler
---
 packages/python/src/armonik/common/objects.py |  11 ++
 .../python/src/armonik/worker/taskhandler.py  | 167 +++---------
 packages/python/tests/test_taskhandler.py     |  29 +--
 3 files changed, 40 insertions(+), 167 deletions(-)

diff --git a/packages/python/src/armonik/common/objects.py b/packages/python/src/armonik/common/objects.py
index 3eacb0eca..db89ac00b 100644
--- a/packages/python/src/armonik/common/objects.py
+++ b/packages/python/src/armonik/common/objects.py
@@ -6,6 +6,7 @@
 
 from deprecation import deprecated
 
+from ..protogen.common.agent_common_pb2 import ResultMetaData
 from ..protogen.common.applications_common_pb2 import ApplicationRaw
 from ..protogen.common.tasks_common_pb2 import TaskDetailed
 from .filter import (
@@ -455,6 +456,16 @@ def from_message(cls, result_raw: ResultRaw) -> "Result":
             size=result_raw.size,
         )
 
+    @classmethod
+    def from_result_metadata(cls, result_metadata: ResultMetaData) -> "Result":
+        return cls(
+            session_id=result_metadata.session_id,
+            name=result_metadata.name,
+            status=result_metadata.status,
+            created_at=timestamp_to_datetime(result_metadata.created_at),
+            result_id=result_metadata.result_id,
+        )
+
     def __eq__(self, other: "Result") -> bool:
         return (
             self.session_id == other.session_id
diff --git a/packages/python/src/armonik/worker/taskhandler.py b/packages/python/src/armonik/worker/taskhandler.py
index 508150670..4fa518f16 100644
--- a/packages/python/src/armonik/worker/taskhandler.py
+++ b/packages/python/src/armonik/worker/taskhandler.py
@@ -1,11 +1,11 @@
 from __future__ import annotations
+
 import os
-from deprecation import deprecated
-from typing import Optional, Dict, List, Tuple, Union
+from typing import Optional, Dict, List, Union
 
-from ..common import TaskOptions, TaskDefinition, Task, Result
+from ..common import TaskOptions, TaskDefinition, Result, Task
+from ..common.helpers import batched
 from ..protogen.common.agent_common_pb2 import (
-    CreateTaskRequest,
     CreateResultsMetaDataRequest,
     CreateResultsMetaDataResponse,
     NotifyResultDataRequest,
@@ -14,15 +14,10 @@
     SubmitTasksRequest,
 )
 from ..protogen.common.objects_pb2 import (
-    TaskRequest,
-    DataChunk,
-    InitTaskRequest,
-    TaskRequestHeader,
     Configuration,
 )
-from ..protogen.worker.agent_service_pb2_grpc import AgentStub
 from ..protogen.common.worker_common_pb2 import ProcessRequest
-from ..common.helpers import batched
+from ..protogen.worker.agent_service_pb2_grpc import AgentStub
 
 
 class TaskHandler:
@@ -47,72 +42,12 @@ def __init__(self, request: ProcessRequest, agent_client: AgentStub):
             with open(os.path.join(self.data_folder, dd), "rb") as f:
                 self.data_dependencies[dd] = f.read()
 
-    @deprecated(
-        deprecated_in="3.15.0",
-        details="Use submit_tasks and instead and create the payload using create_result_metadata and send_result",
-    )
-    def create_tasks(
-        self, tasks: List[TaskDefinition], task_options: Optional[TaskOptions] = None
-    ) -> Tuple[List[Task], List[str]]:
-        """Create new tasks for ArmoniK
-
-        Args:
-            tasks: List of task definitions
-            task_options: Task Options used for this batch of tasks
-
-        Returns:
-            Tuple containing the list of successfully sent tasks, and
-            the list of submission errors if any
-        """
-        task_requests = []
-
-        for t in tasks:
-            task_request = TaskRequest()
-            task_request.expected_output_keys.extend(t.expected_output_ids)
-            task_request.data_dependencies.extend(t.data_dependencies)
-            task_request.payload = t.payload
-            task_requests.append(task_request)
-        assert self.configuration is not None
-        create_tasks_reply = self._client.CreateTask(
-            _to_request_stream(
-                task_requests,
-                self.token,
-                task_options.to_message() if task_options is not None else None,
-                self.configuration.data_chunk_max_size,
-            )
-        )
-        ret = create_tasks_reply.WhichOneof("Response")
-        if ret is None or ret == "error":
-            raise Exception(f"Issue with server when submitting tasks : {create_tasks_reply.error}")
-        elif ret == "creation_status_list":
-            tasks_created = []
-            tasks_creation_failed = []
-            for creation_status in create_tasks_reply.creation_status_list.creation_statuses:
-                if creation_status.WhichOneof("Status") == "task_info":
-                    tasks_created.append(
-                        Task(
-                            id=creation_status.task_info.task_id,
-                            session_id=self.session_id,
-                            expected_output_ids=[
-                                k for k in creation_status.task_info.expected_output_keys
-                            ],
-                            data_dependencies=[
-                                k for k in creation_status.task_info.data_dependencies
-                            ],
-                        )
-                    )
-                else:
-                    tasks_creation_failed.append(creation_status.error)
-        else:
-            raise Exception("Unknown value")
-        return tasks_created, tasks_creation_failed
-
     def submit_tasks(
         self,
         tasks: List[TaskDefinition],
         default_task_options: Optional[TaskOptions] = None,
         batch_size: Optional[int] = 100,
-    ) -> None:
+    ) -> List[Task]:
         """Submit tasks to the agent.
 
         Args:
@@ -120,6 +55,7 @@
             default_task_options: Default Task Options used if a task has its options not set
             batch_size: Batch size for submission
         """
+        submitted_tasks: List[Task] = []
         for tasks_batch in batched(tasks, batch_size):
             task_creations = []
 
@@ -142,13 +78,23 @@
             if default_task_options:
                 request.task_options = (default_task_options.to_message(),)
-            self._client.SubmitTasks(request)
+            submitted_tasks.extend(
+                Task(
+                    id=t.task_id,
+                    expected_output_ids=list(t.expected_output_ids),
+                    data_dependencies=list(t.data_dependencies),
+                    session_id=self.session_id,
+                    payload_id=self.payload_id,
+                )
+                for t in self._client.SubmitTasks(request).task_infos
+            )
+        return submitted_tasks
 
     def send_results(self, results_data: Dict[str, Union[bytes, bytearray]]) -> None:
         """Send results.
 
         Args:
-            result_data: A dictionnary mapping each result ID to its data.
+            results_data: A dictionary mapping each result ID to its data.
         """
         for result_id, result_data in results_data.items():
             with open(os.path.join(self.data_folder, result_id), "wb") as f:
                 f.write(result_data)
@@ -167,7 +113,7 @@ def send_results(self, results_data: Dict[str, Union[bytes, bytearray]]) -> None
 
     def create_results_metadata(
         self, result_names: List[str], batch_size: int = 100
-    ) -> Dict[str, List[Result]]:
+    ) -> Dict[str, Result]:
         """
         Create the metadata of multiple results at once.
         Data have to be uploaded separately.
@@ -177,21 +123,21 @@
             batch_size: Batch size for querying.
 
         Return:
-            A dictionnary mapping each result name to its result summary.
+            A dictionary mapping each result name to its result summary.
""" results = {} for result_names_batch in batched(result_names, batch_size): request = CreateResultsMetaDataRequest( results=[ CreateResultsMetaDataRequest.ResultCreate(name=result_name) - for result_name in result_names + for result_name in result_names_batch ], session_id=self.session_id, communication_token=self.token, ) response: CreateResultsMetaDataResponse = self._client.CreateResultsMetaData(request) for result_message in response.results: - results[result_message.name] = Result.from_message(result_message) + results[result_message.name] = Result.from_result_metadata(result_message) return results def create_results( @@ -200,11 +146,11 @@ def create_results( """Create one result with data included in the request. Args: - results_data: A dictionnary mapping the result names to their actual data. + results_data: A dictionary mapping the result names to their actual data. batch_size: Batch size for querying. Return: - A dictionnary mappin each result name to its corresponding result summary. + A dictionary mapping each result name to its corresponding result summary. """ results = {} for results_ids_batch in batched(results_data.keys(), batch_size): @@ -218,66 +164,5 @@ def create_results( ) response: CreateResultsResponse = self._client.CreateResults(request) for message in response.results: - results[message.name] = Result.from_message(message) + results[message.name] = Result.from_result_metadata(message) return results - - -def _to_request_stream_internal(request, communication_token, is_last, chunk_max_size): - req = CreateTaskRequest( - init_task=InitTaskRequest( - header=TaskRequestHeader( - data_dependencies=request.data_dependencies, - expected_output_keys=request.expected_output_keys, - ) - ), - communication_token=communication_token, - ) - yield req - start = 0 - payload_length = len(request.payload) - if payload_length == 0: - req = CreateTaskRequest( - task_payload=DataChunk(data=b""), communication_token=communication_token - ) - yield req - while start < payload_length: - chunk_size = min(chunk_max_size, payload_length - start) - req = CreateTaskRequest( - task_payload=DataChunk(data=request.payload[start : start + chunk_size]), - communication_token=communication_token, - ) - yield req - start += chunk_size - req = CreateTaskRequest( - task_payload=DataChunk(data_complete=True), - communication_token=communication_token, - ) - yield req - - if is_last: - req = CreateTaskRequest( - init_task=InitTaskRequest(last_task=True), - communication_token=communication_token, - ) - yield req - - -def _to_request_stream(requests, communication_token, t_options, chunk_max_size): - if t_options is None: - req = CreateTaskRequest( - init_request=CreateTaskRequest.InitRequest(), - communication_token=communication_token, - ) - else: - req = CreateTaskRequest( - init_request=CreateTaskRequest.InitRequest(task_options=t_options), - communication_token=communication_token, - ) - yield req - if len(requests) == 0: - return - for r in requests[:-1]: - for req in _to_request_stream_internal(r, communication_token, False, chunk_max_size): - yield req - for req in _to_request_stream_internal(requests[-1], communication_token, True, chunk_max_size): - yield req diff --git a/packages/python/tests/test_taskhandler.py b/packages/python/tests/test_taskhandler.py index ce00d9add..bf602b530 100644 --- a/packages/python/tests/test_taskhandler.py +++ b/packages/python/tests/test_taskhandler.py @@ -1,6 +1,5 @@ import datetime import logging -import warnings from .conftest import all_rpc_called, 
 from .conftest import all_rpc_called, rpc_called, get_client, data_folder
 from armonik.common import TaskDefinition, TaskOptions
@@ -53,27 +52,6 @@ def test_taskhandler_init(self):
         assert task_handler.payload == "payload".encode()
         assert task_handler.data_dependencies == {"dd-id": "dd".encode()}
 
-    def test_create_task(self):
-        with warnings.catch_warnings(record=True) as w:
-            # Cause all warnings to always be triggered.
-            warnings.simplefilter("always")
-
-            task_handler = TaskHandler(self.request, get_client("Agent"))
-            tasks, errors = task_handler.create_tasks(
-                [
-                    TaskDefinition(
-                        payload=b"payload",
-                        expected_output_ids=["result-id"],
-                        data_dependencies=[],
-                    )
-                ]
-            )
-
-            assert issubclass(w[-1].category, DeprecationWarning)
-            assert rpc_called("Agent", "CreateTask")
-            assert tasks == []
-            assert errors == []
-
     def test_submit_tasks(self):
         task_handler = TaskHandler(self.request, get_client("Agent"))
         tasks = task_handler.submit_tasks(
@@ -87,13 +65,12 @@
             )
 
         assert rpc_called("Agent", "SubmitTasks")
-        assert tasks is None
+        assert tasks is not None
 
     def test_send_results(self):
         task_handler = TaskHandler(self.request, get_client("Agent"))
-        resuls = task_handler.send_results({"result-id": b"result data"})
+        task_handler.send_results({"result-id": b"result data"})
         assert rpc_called("Agent", "NotifyResultData")
-        assert resuls is None
 
     def test_create_result_metadata(self):
         task_handler = TaskHandler(self.request, get_client("Agent"))
@@ -112,5 +89,5 @@ def test_create_results(self):
 
     def test_service_fully_implemented(self):
         assert all_rpc_called(
-            "Agent", missings=["CreateTask", "GetCommonData", "GetDirectData", "GetResourceData"]
+            "Agent", missings=["CreateTask", "GetCommonData", "GetDirectData", "GetResourceData"]
        )
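
The removed create_tasks wrapped the old CreateTask streaming RPC. Its replacement splits the work in two: create the payload as a result (create_results_metadata, then send_results) and submit with submit_tasks, which with this patch returns the submitted tasks instead of None. Below is a minimal sketch of that flow inside a worker, not part of the patch: the process function, the handler argument, and the result names "subtask-payload" and "subtask-output" are illustrative, and it assumes TaskDefinition accepts a payload_id referencing the uploaded result.

    from armonik.common import TaskDefinition
    from armonik.worker.taskhandler import TaskHandler

    def process(handler: TaskHandler) -> None:
        # Create result metadata first; the data is uploaded separately.
        results = handler.create_results_metadata(["subtask-payload", "subtask-output"])
        payload = results["subtask-payload"]  # Dict[str, Result], keyed by name
        output = results["subtask-output"]

        # Upload the payload bytes through the agent (NotifyResultData RPC).
        handler.send_results({payload.result_id: b"serialized subtask input"})

        # Submit the subtask; the returned list now carries the created task IDs.
        submitted = handler.submit_tasks(
            [
                TaskDefinition(
                    payload_id=payload.result_id,
                    expected_output_ids=[output.result_id],
                    data_dependencies=[],
                )
            ]
        )
        for task in submitted:
            print(task.id, task.expected_output_ids)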
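When the data is already in hand and small, create_results performs metadata creation and upload in a single request. Its response, like that of create_results_metadata, carries ResultMetaData messages rather than ResultRaw, which is why both call sites switch from Result.from_message to the new Result.from_result_metadata. Continuing the sketch above, with the result name "stats" again illustrative:

    # Create and upload a small result in one call.
    created = handler.create_results({"stats": b'{"count": 42}'})
    stats = created["stats"]  # a Result parsed via from_result_metadata
    print(stats.result_id, stats.status, stats.created_at)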