Skip to content

Commit

Permalink
feat: Python API update tasks service (#458)
Browse files Browse the repository at this point in the history
  • Loading branch information
qdelamea-aneo authored Jan 4, 2024
2 parents e0e3cb8 + 4064ad2 commit b853582
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 17 deletions.
2 changes: 1 addition & 1 deletion packages/python/src/armonik/client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .submitter import ArmoniKSubmitter
from .tasks import ArmoniKTasks
from .tasks import ArmoniKTasks, TaskFieldFilter
from .results import ArmoniKResult
from .versions import ArmoniKVersions
108 changes: 97 additions & 11 deletions packages/python/src/armonik/client/tasks.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from __future__ import annotations
from grpc import Channel
from typing import cast, Tuple, List
from typing import cast, Dict, Optional, Tuple, List

from ..common import Task, Direction
from ..common import Task, Direction, TaskDefinition, TaskOptions, TaskStatus
from ..common.filter import StringFilter, StatusFilter, DateFilter, NumberFilter, Filter, DurationFilter
from ..protogen.client.tasks_service_pb2_grpc import TasksStub
from ..protogen.common.tasks_common_pb2 import GetTaskRequest, GetTaskResponse, ListTasksRequest, ListTasksDetailedResponse
from ..protogen.common.tasks_common_pb2 import GetTaskRequest, GetTaskResponse, ListTasksRequest, ListTasksDetailedResponse, CancelTasksRequest, CancelTasksResponse, GetResultIdsRequest, GetResultIdsResponse, SubmitTasksRequest, SubmitTasksResponse, CountTasksByStatusRequest, CountTasksByStatusResponse, ListTasksResponse
from ..protogen.common.tasks_filters_pb2 import Filters as rawFilters, FiltersAnd as rawFilterAnd, FilterField as rawFilterField, FilterStatus as rawFilterStatus
from ..protogen.common.sort_direction_pb2 import SortDirection

from ..protogen.common.tasks_fields_pb2 import *
from ..common.helpers import batched


class TaskFieldFilter:
Expand Down Expand Up @@ -77,7 +77,7 @@ def get_task(self, task_id: str) -> Task:
task_response: GetTaskResponse = self._client.GetTask(GetTaskRequest(task_id=task_id))
return Task.from_message(task_response.task)

def list_tasks(self, task_filter: Filter, with_errors: bool = False, page: int = 0, page_size: int = 1000, sort_field: Filter = TaskFieldFilter.TASK_ID, sort_direction: SortDirection = Direction.ASC) -> Tuple[int, List[Task]]:
def list_tasks(self, task_filter: Filter | None = None, with_errors: bool = False, page: int = 0, page_size: int = 1000, sort_field: Filter = TaskFieldFilter.TASK_ID, sort_direction: SortDirection = Direction.ASC, detailed: bool = True) -> Tuple[int, List[Task]]:
"""List tasks
If the total returned exceeds the requested page size, you may want to use this function again and ask for subsequent pages.
Expand All @@ -89,16 +89,102 @@ def list_tasks(self, task_filter: Filter, with_errors: bool = False, page: int =
page_size: size of a page, defaults to 1000
sort_field: field on which to sort the resulting list, defaults to the task_id
sort_direction: direction of the sort, defaults to ascending
detailed: Whether to retrieve the detailed description of the task.
Returns:
A tuple containing :
- The total number of tasks for the given filter
- The obtained list of tasks
"""
request = ListTasksRequest(page=page,
page_size=page_size,
filters=cast(rawFilters, task_filter.to_disjunction().to_message()),
sort=ListTasksRequest.Sort(field=cast(TaskField, sort_field.field), direction=sort_direction),
with_errors=with_errors)
list_response: ListTasksDetailedResponse = self._client.ListTasksDetailed(request)
return list_response.total, [Task.from_message(t) for t in list_response.tasks]
page_size=page_size,
filters=cast(rawFilters, task_filter.to_disjunction().to_message()) if task_filter else None,
sort=ListTasksRequest.Sort(field=cast(TaskField, sort_field.field), direction=sort_direction),
with_errors=with_errors
)
if detailed:
response: ListTasksDetailedResponse = self._client.ListTasksDetailed(request)
return response.total, [Task.from_message(t) for t in response.tasks]
response: ListTasksResponse = self._client.ListTasks(request)
return response.total, [Task.from_message(t) for t in response.tasks]

def cancel_tasks(self, task_ids: List[str], chunk_size: Optional[int] = 500) -> None:
    """Cancel tasks.

    The task IDs are split with ``batched`` and one ``CancelTasks`` request
    is sent per batch.

    Args:
        task_ids: IDs of the tasks to cancel.
        chunk_size: Batch size passed to ``batched`` for cancelling.

    Returns:
        None. The responses of the ``CancelTasks`` calls are not returned
        to the caller.
    """
    for task_id_batch in batched(task_ids, chunk_size):
        self._client.CancelTasks(CancelTasksRequest(task_ids=task_id_batch))

def get_result_ids(self, task_ids: List[str], chunk_size: Optional[int] = 500) -> Dict[str, List[str]]:
    """Retrieve the result IDs attached to each of the given tasks.

    One ``GetResultIds`` request is issued per batch produced by ``batched``.

    Args:
        task_ids: The IDs of the tasks.
        chunk_size: Batch size passed to ``batched`` for retrieval.

    Returns:
        A dictionary mapping the ID of a task to the IDs of its results.
    """
    result_ids_by_task = {}
    for batch in batched(task_ids, chunk_size):
        response: GetResultIdsResponse = self._client.GetResultIds(GetResultIdsRequest(task_id=batch))
        # Later entries overwrite earlier ones for the same task ID.
        result_ids_by_task.update(
            (entry.task_id, list(entry.result_ids)) for entry in response.task_results
        )
    return result_ids_by_task

def count_tasks_by_status(self, task_filter: Filter | None = None) -> Dict[TaskStatus, int]:
    """Count tasks grouped by status.

    Args:
        task_filter: Filter selecting the tasks to count; when falsy, no
            filter is put in the request.

    Returns:
        A dictionary mapping each status to the number of filtered tasks.
    """
    raw_filters = None
    if task_filter:
        raw_filters = cast(rawFilters, task_filter.to_disjunction().to_message())

    response: CountTasksByStatusResponse = self._client.CountTasksByStatus(
        CountTasksByStatusRequest(filters=raw_filters)
    )

    counts = {}
    for entry in response.status:
        counts[TaskStatus(entry.status)] = entry.count
    return counts

def submit_tasks(self, session_id: str, tasks: List[TaskDefinition], default_task_options: Optional[TaskOptions] = None, chunk_size: Optional[int] = 100) -> None:
    """Submit tasks to ArmoniK.

    The task definitions are split with ``batched`` and one ``SubmitTasks``
    request is sent per batch.

    Args:
        session_id: Session Id
        tasks: List of task definitions
        default_task_options: Default Task Options used if a task has its options not set
        chunk_size: Batch size passed to ``batched`` for submission

    Returns:
        None. The responses of the ``SubmitTasks`` calls are not returned
        to the caller.
    """
    for tasks_batch in batched(tasks, chunk_size):
        task_creations = [
            SubmitTasksRequest.TaskCreation(
                expected_output_keys=t.expected_output_ids,
                payload_id=t.payload_id,
                data_dependencies=t.data_dependencies,
                # Per-task options are optional; None leaves the field unset.
                task_options=t.options.to_message() if t.options else None,
            )
            for t in tasks_batch
        ]

        request = SubmitTasksRequest(
            session_id=session_id,
            task_creations=task_creations,
            task_options=default_task_options.to_message() if default_task_options else None,
        )

        self._client.SubmitTasks(request)
6 changes: 5 additions & 1 deletion packages/python/src/armonik/common/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ..protogen.common.objects_pb2 import Empty, Output as WorkerOutput, TaskOptions as RawTaskOptions
from ..protogen.common.task_status_pb2 import TaskStatus as RawTaskStatus
from .enumwrapper import TaskStatus, SessionStatus, ResultStatus
from ..protogen.common.partitions_common_pb2 import PartitionRaw
from ..protogen.common.session_status_pb2 import SessionStatus as RawSessionStatus
from ..protogen.common.sessions_common_pb2 import SessionRaw
from ..protogen.common.result_status_pb2 import ResultStatus as RawResultStatus
Expand Down Expand Up @@ -70,9 +71,11 @@ def to_message(self):

@dataclass()
class TaskDefinition:
payload: bytes
payload_id: str = field(default_factory=str)
payload: bytes = field(default_factory=bytes)
expected_output_ids: List[str] = field(default_factory=list)
data_dependencies: List[str] = field(default_factory=list)
options: Optional[TaskOptions] = None

def __post_init__(self):
if len(self.expected_output_ids) <= 0:
Expand All @@ -89,6 +92,7 @@ class Task:
expected_output_ids: List[str] = field(default_factory=list)
retry_of_ids: List[str] = field(default_factory=list)
status: RawTaskStatus = TaskStatus.UNSPECIFIED
payload_id: Optional[str] = None
status_message: Optional[str] = None
options: Optional[TaskOptions] = None
created_at: Optional[datetime] = None
Expand Down
10 changes: 6 additions & 4 deletions packages/python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import pytest
import requests

from armonik.client import ArmoniKVersions
from armonik.client import ArmoniKTasks, ArmoniKVersions
from armonik.protogen.worker.agent_service_pb2_grpc import AgentStub
from typing import List
from typing import List, Union


# Mock server endpoints used for the tests.
Expand Down Expand Up @@ -54,7 +54,7 @@ def clean_up(request):
print("An error occurred when resetting the server: " + str(e))


def get_client(client_name: str, endpoint: str = grpc_endpoint) -> ArmoniKVersions:
def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[ArmoniKTasks, ArmoniKVersions]:
"""
Get the ArmoniK client instance based on the specified service name.
Expand All @@ -63,7 +63,7 @@ def get_client(client_name: str, endpoint: str = grpc_endpoint) -> ArmoniKVersio
endpoint (str, optional): The gRPC server endpoint. Defaults to grpc_endpoint.
Returns:
ArmoniKVersions
Union[ArmoniKTasks, ArmoniKVersions]
An instance of the specified ArmoniK client.
Raises:
Expand All @@ -75,6 +75,8 @@ def get_client(client_name: str, endpoint: str = grpc_endpoint) -> ArmoniKVersio
"""
channel = grpc.insecure_channel(endpoint).__enter__()
match client_name:
case "Tasks":
return ArmoniKTasks(channel)
case "Versions":
return ArmoniKVersions(channel)
case _:
Expand Down
123 changes: 123 additions & 0 deletions packages/python/tests/test_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import datetime

from .conftest import all_rpc_called, rpc_called, get_client
from armonik.client import ArmoniKTasks, TaskFieldFilter
from armonik.common import Task, TaskDefinition, TaskOptions, TaskStatus, Output


class TestArmoniKTasks:
    """Tests for the ArmoniKTasks client, one per RPC of the Tasks service."""

    def test_get_task(self):
        client: ArmoniKTasks = get_client("Tasks")
        fetched = client.get_task("task-id")

        # Every timestamp checked below is compared against the Unix epoch in UTC.
        epoch = datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc)
        expected_options = TaskOptions(
            max_duration=datetime.timedelta(seconds=1),
            priority=1,
            max_retries=1,
            partition_id='partition-id',
            application_name='application-name',
            application_version='application-version',
            application_namespace='application-namespace',
            application_service='application-service',
            engine_type='engine-type',
            options={}
        )

        assert rpc_called("Tasks", "GetTask")
        assert isinstance(fetched, Task)
        assert fetched.id == 'task-id'
        assert fetched.session_id == 'session-id'
        assert fetched.data_dependencies == []
        assert fetched.expected_output_ids == []
        assert fetched.retry_of_ids == []
        assert fetched.status == TaskStatus.COMPLETED
        assert fetched.payload_id is None
        assert fetched.status_message == ''
        assert fetched.options == expected_options
        assert fetched.created_at == epoch
        assert fetched.submitted_at == epoch
        assert fetched.started_at == epoch
        assert fetched.ended_at == epoch
        assert fetched.pod_ttl == epoch
        assert fetched.output == Output(error='')
        assert fetched.pod_hostname == ''
        assert fetched.received_at == epoch
        assert fetched.acquired_at == epoch

    def test_list_tasks_detailed_no_filter(self):
        client: ArmoniKTasks = get_client("Tasks")
        total, listed = client.list_tasks()
        assert rpc_called("Tasks", "ListTasksDetailed")
        # TODO: Mock must be updated to return something and so that changes the following assertions
        assert total == 0
        assert listed == []

    def test_list_tasks_detailed_with_filter(self):
        client: ArmoniKTasks = get_client("Tasks")
        completed_only = TaskFieldFilter.STATUS == TaskStatus.COMPLETED
        total, listed = client.list_tasks(completed_only)
        assert rpc_called("Tasks", "ListTasksDetailed", 2)
        # TODO: Mock must be updated to return something and so that changes the following assertions
        assert total == 0
        assert listed == []

    def test_list_tasks_no_detailed_no_filter(self):
        client: ArmoniKTasks = get_client("Tasks")
        total, listed = client.list_tasks(detailed=False)
        assert rpc_called("Tasks", "ListTasks")
        # TODO: Mock must be updated to return something and so that changes the following assertions
        assert total == 0
        assert listed == []

    def test_cancel_tasks(self):
        client: ArmoniKTasks = get_client("Tasks")
        outcome = client.cancel_tasks(["task-id-1", "task-id-2"])

        assert rpc_called("Tasks", "CancelTasks")
        assert outcome is None

    def test_get_result_ids(self):
        client: ArmoniKTasks = get_client("Tasks")
        result_ids = client.get_result_ids(["task-id-1", "task-id-2"])
        assert rpc_called("Tasks", "GetResultIds")
        # TODO: Mock must be updated to return something and so that changes the following assertions
        assert result_ids == {}

    def test_count_tasks_by_status_no_filter(self):
        client: ArmoniKTasks = get_client("Tasks")
        counts = client.count_tasks_by_status()
        assert rpc_called("Tasks", "CountTasksByStatus")
        # TODO: Mock must be updated to return something and so that changes the following assertions
        assert counts == {}

    def test_count_tasks_by_status_with_filter(self):
        client: ArmoniKTasks = get_client("Tasks")
        completed_only = TaskFieldFilter.STATUS == TaskStatus.COMPLETED
        counts = client.count_tasks_by_status(completed_only)
        assert rpc_called("Tasks", "CountTasksByStatus", 2)
        # TODO: Mock must be updated to return something and so that changes the following assertions
        assert counts == {}

    def test_submit_tasks(self):
        client: ArmoniKTasks = get_client("Tasks")
        # The same option values are used for the task and for the defaults.
        option_values = dict(
            max_duration=datetime.timedelta(seconds=1),
            priority=1,
            max_retries=1,
        )
        definition = TaskDefinition(
            payload_id="payload-id",
            expected_output_ids=["result-id"],
            data_dependencies=[],
            options=TaskOptions(**option_values),
        )
        submitted = client.submit_tasks(
            "session-id",
            [definition],
            default_task_options=TaskOptions(**option_values),
        )
        assert rpc_called("Tasks", "SubmitTasks")
        # TODO: Mock must be updated to return something and so that changes the following assertions
        assert submitted is None

    def test_service_fully_implemented(self):
        assert all_rpc_called("Tasks")

0 comments on commit b853582

Please sign in to comment.