
feat: Removing deprecated task handler function
fix: Fixed result parsing in task handler
dbrasseur-aneo committed Aug 23, 2024
1 parent 07b4e6e commit 9cfc4b2
Showing 3 changed files with 40 additions and 167 deletions.
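For context, a minimal before-and-after sketch of the migration this commit implies. The task_handler and task_definitions names are illustrative placeholders, not taken from the diff:

    # Before (removed in this commit): the deprecated streaming-based API
    # returned a (created_tasks, submission_errors) tuple.
    tasks, errors = task_handler.create_tasks(task_definitions)

    # After: submit_tasks is the single entry point and now returns the
    # submitted Task objects directly (it previously returned None).
    tasks = task_handler.submit_tasks(task_definitions)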
11 changes: 11 additions & 0 deletions packages/python/src/armonik/common/objects.py
@@ -6,6 +6,7 @@

from deprecation import deprecated

from ..protogen.common.agent_common_pb2 import ResultMetaData
from ..protogen.common.applications_common_pb2 import ApplicationRaw
from ..protogen.common.tasks_common_pb2 import TaskDetailed
from .filter import (
@@ -455,6 +456,16 @@ def from_message(cls, result_raw: ResultRaw) -> "Result":
            size=result_raw.size,
        )

    @classmethod
    def from_result_metadata(cls, result_metadata: ResultMetaData) -> "Result":
        return cls(
            session_id=result_metadata.session_id,
            name=result_metadata.name,
            status=result_metadata.status,
            created_at=timestamp_to_datetime(result_metadata.created_at),
            result_id=result_metadata.result_id,
        )

    def __eq__(self, other: "Result") -> bool:
        return (
            self.session_id == other.session_id
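A quick sketch of the new classmethod in use. The literal values below are invented for illustration; only the field names (session_id, name, status, created_at, result_id) are the ones from_result_metadata actually reads:

    from armonik.common import Result
    from armonik.protogen.common.agent_common_pb2 import ResultMetaData

    # Normally this message arrives inside a CreateResultsMetaDataResponse;
    # here it is built by hand purely to show the conversion.
    meta = ResultMetaData(session_id="session-id", name="output", result_id="result-id")
    result = Result.from_result_metadata(meta)
    assert result.name == "output" and result.result_id == "result-id"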
167 changes: 26 additions & 141 deletions packages/python/src/armonik/worker/taskhandler.py
@@ -1,11 +1,11 @@
from __future__ import annotations

import os
from deprecation import deprecated
from typing import Optional, Dict, List, Tuple, Union
from typing import Optional, Dict, List, Union

from ..common import TaskOptions, TaskDefinition, Task, Result
from ..common import TaskOptions, TaskDefinition, Result, Task
from ..common.helpers import batched
from ..protogen.common.agent_common_pb2 import (
    CreateTaskRequest,
    CreateResultsMetaDataRequest,
    CreateResultsMetaDataResponse,
    NotifyResultDataRequest,
@@ -14,15 +14,10 @@
    SubmitTasksRequest,
)
from ..protogen.common.objects_pb2 import (
    TaskRequest,
    DataChunk,
    InitTaskRequest,
    TaskRequestHeader,
    Configuration,
)
from ..protogen.worker.agent_service_pb2_grpc import AgentStub
from ..protogen.common.worker_common_pb2 import ProcessRequest
from ..common.helpers import batched
from ..protogen.worker.agent_service_pb2_grpc import AgentStub


class TaskHandler:
@@ -47,79 +42,20 @@ def __init__(self, request: ProcessRequest, agent_client: AgentStub):
            with open(os.path.join(self.data_folder, dd), "rb") as f:
                self.data_dependencies[dd] = f.read()

    @deprecated(
        deprecated_in="3.15.0",
        details="Use submit_tasks and instead and create the payload using create_result_metadata and send_result",
    )
    def create_tasks(
        self, tasks: List[TaskDefinition], task_options: Optional[TaskOptions] = None
    ) -> Tuple[List[Task], List[str]]:
        """Create new tasks for ArmoniK
        Args:
            tasks: List of task definitions
            task_options: Task Options used for this batch of tasks
        Returns:
            Tuple containing the list of successfully sent tasks, and
            the list of submission errors if any
        """
        task_requests = []

        for t in tasks:
            task_request = TaskRequest()
            task_request.expected_output_keys.extend(t.expected_output_ids)
            task_request.data_dependencies.extend(t.data_dependencies)
            task_request.payload = t.payload
            task_requests.append(task_request)
        assert self.configuration is not None
        create_tasks_reply = self._client.CreateTask(
            _to_request_stream(
                task_requests,
                self.token,
                task_options.to_message() if task_options is not None else None,
                self.configuration.data_chunk_max_size,
            )
        )
        ret = create_tasks_reply.WhichOneof("Response")
        if ret is None or ret == "error":
            raise Exception(f"Issue with server when submitting tasks : {create_tasks_reply.error}")
        elif ret == "creation_status_list":
            tasks_created = []
            tasks_creation_failed = []
            for creation_status in create_tasks_reply.creation_status_list.creation_statuses:
                if creation_status.WhichOneof("Status") == "task_info":
                    tasks_created.append(
                        Task(
                            id=creation_status.task_info.task_id,
                            session_id=self.session_id,
                            expected_output_ids=[
                                k for k in creation_status.task_info.expected_output_keys
                            ],
                            data_dependencies=[
                                k for k in creation_status.task_info.data_dependencies
                            ],
                        )
                    )
                else:
                    tasks_creation_failed.append(creation_status.error)
        else:
            raise Exception("Unknown value")
        return tasks_created, tasks_creation_failed

    def submit_tasks(
        self,
        tasks: List[TaskDefinition],
        default_task_options: Optional[TaskOptions] = None,
        batch_size: Optional[int] = 100,
    ) -> None:
    ) -> List[Task]:
        """Submit tasks to the agent.
        Args:
            tasks: List of task definitions
            default_task_options: Default Task Options used if a task has its options not set
            batch_size: Batch size for submission
        """
        submitted_tasks: List[Task] = []
        for tasks_batch in batched(tasks, batch_size):
            task_creations = []

@@ -142,13 +78,23 @@ def submit_tasks(
            if default_task_options:
                request.task_options = (default_task_options.to_message(),)

            self._client.SubmitTasks(request)
            submitted_tasks.extend(
                Task(
                    id=t.task_id,
                    expected_output_ids=list(t.expected_output_ids),
                    data_dependencies=list(t.data_dependencies),
                    session_id=self.session_id,
                    payload_id=self.payload_id,
                )
                for t in self._client.SubmitTasks(request).task_infos
            )
        return submitted_tasks

    def send_results(self, results_data: Dict[str, Union[bytes, bytearray]]) -> None:
        """Send results.
        Args:
            result_data: A dictionnary mapping each result ID to its data.
            results_data: A dictionary mapping each result ID to its data.
        """
        for result_id, result_data in results_data.items():
            with open(os.path.join(self.data_folder, result_id), "wb") as f:
@@ -167,7 +113,7 @@ def send_results(self, results_data: Dict[str, Union[bytes, bytearray]]) -> None

    def create_results_metadata(
        self, result_names: List[str], batch_size: int = 100
    ) -> Dict[str, List[Result]]:
    ) -> Dict[str, Result]:
        """
        Create the metadata of multiple results at once.
        Data have to be uploaded separately.
@@ -177,21 +123,21 @@ def create_results_metadata(
            batch_size: Batch size for querying.
        Return:
            A dictionnary mapping each result name to its result summary.
            A dictionary mapping each result name to its result summary.
        """
        results = {}
        for result_names_batch in batched(result_names, batch_size):
            request = CreateResultsMetaDataRequest(
                results=[
                    CreateResultsMetaDataRequest.ResultCreate(name=result_name)
                    for result_name in result_names
                    for result_name in result_names_batch
                ],
                session_id=self.session_id,
                communication_token=self.token,
            )
            response: CreateResultsMetaDataResponse = self._client.CreateResultsMetaData(request)
            for result_message in response.results:
                results[result_message.name] = Result.from_message(result_message)
                results[result_message.name] = Result.from_result_metadata(result_message)
        return results

    def create_results(
@@ -200,11 +146,11 @@ def create_results(
"""Create one result with data included in the request.
Args:
results_data: A dictionnary mapping the result names to their actual data.
results_data: A dictionary mapping the result names to their actual data.
batch_size: Batch size for querying.
Return:
A dictionnary mappin each result name to its corresponding result summary.
A dictionary mapping each result name to its corresponding result summary.
"""
results = {}
for results_ids_batch in batched(results_data.keys(), batch_size):
@@ -218,66 +164,5 @@
            )
            response: CreateResultsResponse = self._client.CreateResults(request)
            for message in response.results:
                results[message.name] = Result.from_message(message)
                results[message.name] = Result.from_result_metadata(message)
        return results


def _to_request_stream_internal(request, communication_token, is_last, chunk_max_size):
    req = CreateTaskRequest(
        init_task=InitTaskRequest(
            header=TaskRequestHeader(
                data_dependencies=request.data_dependencies,
                expected_output_keys=request.expected_output_keys,
            )
        ),
        communication_token=communication_token,
    )
    yield req
    start = 0
    payload_length = len(request.payload)
    if payload_length == 0:
        req = CreateTaskRequest(
            task_payload=DataChunk(data=b""), communication_token=communication_token
        )
        yield req
    while start < payload_length:
        chunk_size = min(chunk_max_size, payload_length - start)
        req = CreateTaskRequest(
            task_payload=DataChunk(data=request.payload[start : start + chunk_size]),
            communication_token=communication_token,
        )
        yield req
        start += chunk_size
    req = CreateTaskRequest(
        task_payload=DataChunk(data_complete=True),
        communication_token=communication_token,
    )
    yield req

    if is_last:
        req = CreateTaskRequest(
            init_task=InitTaskRequest(last_task=True),
            communication_token=communication_token,
        )
        yield req


def _to_request_stream(requests, communication_token, t_options, chunk_max_size):
    if t_options is None:
        req = CreateTaskRequest(
            init_request=CreateTaskRequest.InitRequest(),
            communication_token=communication_token,
        )
    else:
        req = CreateTaskRequest(
            init_request=CreateTaskRequest.InitRequest(task_options=t_options),
            communication_token=communication_token,
        )
    yield req
    if len(requests) == 0:
        return
    for r in requests[:-1]:
        for req in _to_request_stream_internal(r, communication_token, False, chunk_max_size):
            yield req
    for req in _to_request_stream_internal(requests[-1], communication_token, True, chunk_max_size):
        yield req
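Putting the surviving methods together, the replacement workflow named in the deprecation note looks roughly like this from inside a worker. A sketch only: the result names and payload bytes are invented, and it assumes TaskDefinition accepts a payload_id field in this version of the SDK:

    # 1. Reserve result IDs for the child task's payload and output.
    results = task_handler.create_results_metadata(["child-payload", "child-output"])

    # 2. Upload the payload bytes (send_results writes them to the data folder
    #    and notifies the agent).
    task_handler.send_results({results["child-payload"].result_id: b"child payload"})

    # 3. Submit the child task; submit_tasks now returns the Task objects.
    submitted = task_handler.submit_tasks(
        [
            TaskDefinition(
                payload_id=results["child-payload"].result_id,  # assumed field name
                expected_output_ids=[results["child-output"].result_id],
                data_dependencies=[],
            )
        ]
    )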
29 changes: 3 additions & 26 deletions packages/python/tests/test_taskhandler.py
@@ -1,6 +1,5 @@
import datetime
import logging
import warnings

from .conftest import all_rpc_called, rpc_called, get_client, data_folder
from armonik.common import TaskDefinition, TaskOptions
@@ -53,27 +52,6 @@ def test_taskhandler_init(self):
        assert task_handler.payload == "payload".encode()
        assert task_handler.data_dependencies == {"dd-id": "dd".encode()}

    def test_create_task(self):
        with warnings.catch_warnings(record=True) as w:
            # Cause all warnings to always be triggered.
            warnings.simplefilter("always")

            task_handler = TaskHandler(self.request, get_client("Agent"))
            tasks, errors = task_handler.create_tasks(
                [
                    TaskDefinition(
                        payload=b"payload",
                        expected_output_ids=["result-id"],
                        data_dependencies=[],
                    )
                ]
            )

        assert issubclass(w[-1].category, DeprecationWarning)
        assert rpc_called("Agent", "CreateTask")
        assert tasks == []
        assert errors == []

    def test_submit_tasks(self):
        task_handler = TaskHandler(self.request, get_client("Agent"))
        tasks = task_handler.submit_tasks(
@@ -87,13 +65,12 @@ def test_submit_tasks(self):
        )

        assert rpc_called("Agent", "SubmitTasks")
        assert tasks is None
        assert tasks is not None

    def test_send_results(self):
        task_handler = TaskHandler(self.request, get_client("Agent"))
        resuls = task_handler.send_results({"result-id": b"result data"})
        task_handler.send_results({"result-id": b"result data"})
        assert rpc_called("Agent", "NotifyResultData")
        assert resuls is None

    def test_create_result_metadata(self):
        task_handler = TaskHandler(self.request, get_client("Agent"))
@@ -112,5 +89,5 @@ def test_create_results(self):

    def test_service_fully_implemented(self):
        assert all_rpc_called(
            "Agent", missings=["GetCommonData", "GetDirectData", "GetResourceData"]
            "Agent", missings=["CreateTask", "GetCommonData", "GetDirectData", "GetResourceData"]
        )
