Skip to content

Commit

Permalink
feat: Python API update taskhandler and worker (#463)
Browse files Browse the repository at this point in the history
  • Loading branch information
qdelamea-aneo authored Jan 5, 2024
2 parents e90664c + e3dda39 commit 3f20e50
Show file tree
Hide file tree
Showing 5 changed files with 291 additions and 13 deletions.
6 changes: 6 additions & 0 deletions packages/python/src/armonik/worker/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# Re-export the public surface of the ``armonik.worker`` package so users can
# write ``from armonik.worker import ArmoniKWorker`` directly.
from .worker import ArmoniKWorker
from .taskhandler import TaskHandler
from .seqlogger import ClefLogger

# Names exported by ``from armonik.worker import *``.
__all__ = [
    'ArmoniKWorker',
    'TaskHandler',
    'ClefLogger',
]
103 changes: 92 additions & 11 deletions packages/python/src/armonik/worker/taskhandler.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations
import os
from deprecation import deprecated
from typing import Optional, Dict, List, Tuple, Union, cast

from ..common import TaskOptions, TaskDefinition, Task
from ..protogen.common.agent_common_pb2 import CreateTaskRequest, CreateResultsMetaDataRequest, CreateResultsMetaDataResponse, NotifyResultDataRequest
from ..common import TaskOptions, TaskDefinition, Task, Result
from ..protogen.common.agent_common_pb2 import CreateTaskRequest, CreateResultsMetaDataRequest, CreateResultsMetaDataResponse, NotifyResultDataRequest, CreateResultsRequest, CreateResultsResponse, SubmitTasksRequest, SubmitTasksResponse
from ..protogen.common.objects_pb2 import TaskRequest, DataChunk, InitTaskRequest, TaskRequestHeader, Configuration
from ..protogen.worker.agent_service_pb2_grpc import AgentStub
from ..protogen.common.worker_common_pb2 import ProcessRequest
from ..common.helpers import batched


class TaskHandler:
Expand All @@ -31,6 +33,7 @@ def __init__(self, request: ProcessRequest, agent_client: AgentStub):
with open(os.path.join(self.data_folder, dd), "rb") as f:
self.data_dependencies[dd] = f.read()

@deprecated(deprecated_in="3.15.0", details="Use submit_tasks and instead and create the payload using create_result_metadata and send_result")
def create_tasks(self, tasks: List[TaskDefinition], task_options: Optional[TaskOptions] = None) -> Tuple[List[Task], List[str]]:
"""Create new tasks for ArmoniK
Expand Down Expand Up @@ -67,21 +70,99 @@ def create_tasks(self, tasks: List[TaskDefinition], task_options: Optional[TaskO
raise Exception("Unknown value")
return tasks_created, tasks_creation_failed

def submit_tasks(self, tasks: List[TaskDefinition], default_task_options: Optional[TaskOptions] = None, batch_size: Optional[int] = 100) -> None:
    """Submit tasks to the agent.

    Args:
        tasks: List of task definitions.
        default_task_options: Default task options applied by the agent to
            tasks whose definition has no options set.
        batch_size: Maximum number of tasks sent per SubmitTasks request.
    """
    for tasks_batch in batched(tasks, batch_size):
        task_creations = []
        for definition in tasks_batch:
            # Protobuf embedded-message fields cannot be assigned with ``=``
            # after construction, so pass task_options through the constructor.
            creation_kwargs = {
                "expected_output_keys": definition.expected_output_ids,
                "payload_id": definition.payload_id,
                "data_dependencies": definition.data_dependencies,
            }
            if definition.options:
                creation_kwargs["task_options"] = definition.options.to_message()
            task_creations.append(SubmitTasksRequest.TaskCreation(**creation_kwargs))

        request_kwargs = {
            "session_id": self.session_id,
            "communication_token": self.token,
            "task_creations": task_creations,
        }
        if default_task_options:
            # BUG FIX: the previous code did
            # ``request.task_options = default_task_options.to_message(),``
            # — the trailing comma made the right-hand side a tuple, and direct
            # assignment to a message field is rejected by protobuf anyway.
            request_kwargs["task_options"] = default_task_options.to_message()

        self._client.SubmitTasks(SubmitTasksRequest(**request_kwargs))

def get_results_ids(self, names: List[str]) -> Dict[str, str]:
    """Create result metadata for each name and return a name -> result ID mapping."""
    request = CreateResultsMetaDataRequest(
        results=[CreateResultsMetaDataRequest.ResultCreate(name=name) for name in names],
        session_id=self.session_id,
        communication_token=self.token,
    )
    response = cast(CreateResultsMetaDataResponse, self._client.CreateResultsMetaData(request))
    return {result.name: result.result_id for result in response.results}
def send_results(self, results_data: Dict[str, bytes | bytearray]) -> None:
    """Send results to the agent.

    Each result's data is written into the shared data folder under its
    result ID, then the agent is notified that all the data is available.

    Args:
        results_data: A dictionary mapping each result ID to its data.
    """
    for result_id, data in results_data.items():
        with open(os.path.join(self.data_folder, result_id), "wb") as f:
            f.write(data)

    request = NotifyResultDataRequest(
        ids=[
            NotifyResultDataRequest.ResultIdentifier(session_id=self.session_id, result_id=result_id)
            for result_id in results_data
        ],
        communication_token=self.token
    )
    self._client.NotifyResultData(request)

def create_results_metadata(self, result_names: List[str], batch_size: int = 100) -> Dict[str, Result]:
    """Create the metadata of multiple results at once.

    Data have to be uploaded separately (see ``send_results``).

    Args:
        result_names: The names of the results to create.
        batch_size: Maximum number of results created per request.

    Return:
        A dictionary mapping each result name to its result summary.
    """
    results = {}
    for result_names_batch in batched(result_names, batch_size):
        request = CreateResultsMetaDataRequest(
            # BUG FIX: this comprehension previously iterated the full
            # ``result_names`` list, so every batch re-created metadata for
            # ALL names instead of only the names of the current batch.
            results=[CreateResultsMetaDataRequest.ResultCreate(name=result_name) for result_name in result_names_batch],
            session_id=self.session_id,
            communication_token=self.token
        )
        response: CreateResultsMetaDataResponse = self._client.CreateResultsMetaData(request)
        for result_message in response.results:
            results[result_message.name] = Result.from_message(result_message)
    return results

def create_results(self, results_data: Dict[str, bytes], batch_size: int = 1) -> Dict[str, Result]:
    """Create results with their data included directly in the request.

    Args:
        results_data: A dictionary mapping the result names to their actual data.
        batch_size: Maximum number of results created per request.

    Return:
        A dictionary mapping each result name to its corresponding result summary.
    """
    results = {}
    # The batches iterate over result *names*; the data is looked up per name.
    for names_batch in batched(results_data.keys(), batch_size):
        request = CreateResultsRequest(
            results=[CreateResultsRequest.ResultCreate(name=name, data=results_data[name]) for name in names_batch],
            session_id=self.session_id,
            communication_token=self.token
        )
        response: CreateResultsResponse = self._client.CreateResults(request)
        for message in response.results:
            results[message.name] = Result.from_message(message)
    return results

def _to_request_stream_internal(request, communication_token, is_last, chunk_max_size):
req = CreateTaskRequest(
Expand Down
7 changes: 5 additions & 2 deletions packages/python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def clean_up(request):
# Remove the temporary files created for testing
os.remove(os.path.join(data_folder, "payload-id"))
os.remove(os.path.join(data_folder, "dd-id"))
os.remove(os.path.join(data_folder, "result-id"))

# Reset the mock server counters
try:
Expand All @@ -54,7 +55,7 @@ def clean_up(request):
print("An error occurred when resetting the server: " + str(e))


def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[ArmoniKPartitions, ArmoniKResults, ArmoniKSessions, ArmoniKTasks, ArmoniKVersions]:
def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[AgentStub, ArmoniKPartitions, ArmoniKResults, ArmoniKSessions, ArmoniKTasks, ArmoniKVersions]:
"""
Get the ArmoniK client instance based on the specified service name.
Expand All @@ -63,7 +64,7 @@ def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[ArmoniK
endpoint (str, optional): The gRPC server endpoint. Defaults to grpc_endpoint.
Returns:
Union[ArmoniKPartitions, ArmoniKResults, ArmoniKSessions, ArmoniKTasks, ArmoniKVersions]
Union[AgentStub, ArmoniKPartitions, ArmoniKResults, ArmoniKSessions, ArmoniKTasks, ArmoniKVersions]
An instance of the specified ArmoniK client.
Raises:
Expand All @@ -75,6 +76,8 @@ def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[ArmoniK
"""
channel = grpc.insecure_channel(endpoint).__enter__()
match client_name:
case "Agent":
return AgentStub(channel)
case "Partitions":
return ArmoniKPartitions(channel)
case "Results":
Expand Down
108 changes: 108 additions & 0 deletions packages/python/tests/test_taskhandler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import datetime
import logging
import warnings

from .conftest import all_rpc_called, rpc_called, get_client, data_folder
from armonik.common import TaskDefinition, TaskOptions
from armonik.worker import TaskHandler
from armonik.protogen.worker.agent_service_pb2_grpc import AgentStub
from armonik.protogen.common.worker_common_pb2 import ProcessRequest
from armonik.protogen.common.objects_pb2 import Configuration


logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)


class TestTaskHandler:
    """Tests for TaskHandler, run against the mock agent service."""

    # ProcessRequest mirroring what the agent sends for a single task; the
    # payload and data-dependency files are expected to exist in data_folder.
    request = ProcessRequest(
        communication_token="token",
        session_id="session-id",
        task_id="task-id",
        expected_output_keys=["result-id"],
        payload_id="payload-id",
        data_dependencies=["dd-id"],
        data_folder=data_folder,
        configuration=Configuration(data_chunk_max_size=8000),
        task_options=TaskOptions(
            max_duration=datetime.timedelta(seconds=1),
            priority=1,
            max_retries=1
        ).to_message()
    )

    def test_taskhandler_init(self):
        task_handler = TaskHandler(self.request, get_client("Agent"))

        assert task_handler.session_id == "session-id"
        assert task_handler.task_id == "task-id"
        assert task_handler.task_options == TaskOptions(
            max_duration=datetime.timedelta(seconds=1),
            priority=1,
            max_retries=1,
            partition_id='',
            application_name='',
            application_version='',
            application_namespace='',
            application_service='',
            engine_type='',
            options={}
        )
        assert task_handler.token == "token"
        assert task_handler.expected_results == ["result-id"]
        assert task_handler.configuration == Configuration(data_chunk_max_size=8000)
        assert task_handler.payload_id == "payload-id"
        assert task_handler.data_folder == data_folder
        assert task_handler.payload == "payload".encode()
        assert task_handler.data_dependencies == {"dd-id": "dd".encode()}

    def test_create_task(self):
        # create_tasks is deprecated: verify it still works AND emits the warning.
        with warnings.catch_warnings(record=True) as w:
            # Cause all warnings to always be triggered.
            warnings.simplefilter("always")

            task_handler = TaskHandler(self.request, get_client("Agent"))
            tasks, errors = task_handler.create_tasks([TaskDefinition(
                payload=b"payload",
                expected_output_ids=["result-id"],
                data_dependencies=[])])

            assert issubclass(w[-1].category, DeprecationWarning)
            assert rpc_called("Agent", "CreateTask")
            assert tasks == []
            assert errors == []

    def test_submit_tasks(self):
        task_handler = TaskHandler(self.request, get_client("Agent"))
        tasks = task_handler.submit_tasks([TaskDefinition(payload_id="payload-id",
                                                          expected_output_ids=["result-id"],
                                                          data_dependencies=[])]
                                          )

        assert rpc_called("Agent", "SubmitTasks")
        assert tasks is None

    def test_send_results(self):
        task_handler = TaskHandler(self.request, get_client("Agent"))
        results = task_handler.send_results({"result-id": b"result data"})
        assert rpc_called("Agent", "NotifyResultData")
        assert results is None

    def test_create_result_metadata(self):
        task_handler = TaskHandler(self.request, get_client("Agent"))
        results = task_handler.create_results_metadata(["result-name"])

        assert rpc_called("Agent", "CreateResultsMetaData")
        # TODO: Mock must be updated to return something and so that changes the following assertions
        assert results == {}

    def test_create_results(self):
        task_handler = TaskHandler(self.request, get_client("Agent"))
        results = task_handler.create_results({"result-name": b"test data"})

        assert rpc_called("Agent", "CreateResults")
        assert results == {}

    def test_service_fully_implemented(self):
        assert all_rpc_called("Agent", missings=["GetCommonData", "GetDirectData", "GetResourceData"])
80 changes: 80 additions & 0 deletions packages/python/tests/test_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import datetime
import grpc
import logging
import os
import pytest

from .conftest import data_folder, grpc_endpoint
from armonik.worker import ArmoniKWorker, TaskHandler, ClefLogger
from armonik.common import Output, TaskOptions
from armonik.protogen.common.objects_pb2 import Empty, Configuration
from armonik.protogen.common.worker_common_pb2 import ProcessRequest


def do_nothing(_task_handler: TaskHandler) -> Output:
    """Processing callback that reports success without touching the handler."""
    return Output()


def throw_error(_task_handler: TaskHandler) -> Output:
    """Processing callback that raises instead of returning an Output."""
    raise ValueError("TestError")


def return_error(_task_handler: TaskHandler) -> Output:
    """Processing callback that reports a failed Output carrying 'TestError'."""
    return Output("TestError")


def return_and_send(th: TaskHandler) -> Output:
    """Processing callback that uploads one result, then reports success."""
    first_expected = th.expected_results[0]
    th.send_results({first_expected: b"result"})
    return Output()


class TestWorker:
    """End-to-end tests of ArmoniKWorker.Process against the mock agent endpoint."""

    # ProcessRequest mirroring what the agent sends for one task execution;
    # payload/data-dependency files are expected to exist under data_folder.
    request = ProcessRequest(
        communication_token="token",
        session_id="session-id",
        task_id="task-id",
        expected_output_keys=["result-id"],
        payload_id="payload-id",
        data_dependencies=["dd-id"],
        data_folder=data_folder,
        configuration=Configuration(data_chunk_max_size=8000),
        task_options=TaskOptions(
            max_duration=datetime.timedelta(seconds=1),
            priority=1,
            max_retries=1
        ).to_message()
    )

    def test_do_nothing(self):
        # A no-op processing function should yield a successful Output,
        # and HealthCheck should complete without raising.
        with grpc.insecure_channel(grpc_endpoint) as agent_channel:
            worker = ArmoniKWorker(agent_channel, do_nothing, logger=ClefLogger("TestLogger", level=logging.CRITICAL))
            reply = worker.Process(self.request, None)
            assert Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" else None).success
            worker.HealthCheck(Empty(), None)

    def test_should_return_none(self):
        # When the processing function raises, Process returns None.
        with grpc.insecure_channel(grpc_endpoint) as agent_channel:
            worker = ArmoniKWorker(agent_channel, throw_error, logger=ClefLogger("TestLogger", level=logging.CRITICAL))
            reply = worker.Process(self.request, None)
            assert reply is None

    def test_should_error(self):
        # An error Output returned by the processing function is propagated
        # in the reply's error details.
        with grpc.insecure_channel(grpc_endpoint) as agent_channel:
            worker = ArmoniKWorker(agent_channel, return_error, logger=ClefLogger("TestLogger", level=logging.CRITICAL))
            reply = worker.Process(self.request, None)
            output = Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" else None)
            assert not output.success
            assert output.error == "TestError"

    def test_should_write_result(self):
        # send_results called inside the processing function must write the
        # result file into the shared data folder.
        with grpc.insecure_channel(grpc_endpoint) as agent_channel:
            worker = ArmoniKWorker(agent_channel, return_and_send, logger=ClefLogger("TestLogger", level=logging.DEBUG))
            reply = worker.Process(self.request, None)
            assert reply is not None
            output = Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" else None)
            assert output.success
            assert os.path.exists(os.path.join(data_folder, self.request.expected_output_keys[0]))
            with open(os.path.join(data_folder, self.request.expected_output_keys[0]), "rb") as f:
                value = f.read()
                assert len(value) > 0

0 comments on commit 3f20e50

Please sign in to comment.