Skip to content

Commit

Permalink
feat: Add Task List and filters to Python (#416)
Browse files Browse the repository at this point in the history
  • Loading branch information
aneojgurhem authored Sep 26, 2023
2 parents e6a3c15 + a82e08f commit ceeb413
Show file tree
Hide file tree
Showing 17 changed files with 748 additions and 37 deletions.
18 changes: 17 additions & 1 deletion examples/python/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import grpc
import argparse
from typing import cast
from armonik.client import ArmoniKSubmitter, ArmoniKResult
from armonik.client import ArmoniKSubmitter, ArmoniKResult, ArmoniKTasks
from armonik.client.tasks import TaskFieldFilter
from armonik.common import TaskDefinition, TaskOptions
from datetime import timedelta, datetime
from common import Payload, Result
Expand All @@ -14,6 +15,7 @@ def parse_arguments():
parser.add_argument("-p", "--partition", type=str, help="Partition used for the worker")
parser.add_argument("-v", "--values", type=float, help="List of values to compute instead of x in [0, n[", nargs='+')
parser.add_argument("-n", "--nfirst", type=int, help="Compute from 0 inclusive to n exclusive, n=10 by default", default=10)
parser.add_argument("-l", "--list", action="store_true", help="List tasks of the session at the end")
return parser.parse_args()


Expand Down Expand Up @@ -61,6 +63,20 @@ def main():
# Result is in error
errors = "\n".join(reply.errors)
print(f'Errors : {errors}')

# List tasks
if args.list:
print(f"Listing tasks of session {session_id}")
# Create the tasks client
tasks_client = ArmoniKTasks(channel)

# Request listing of tasks from the session
total_tasks, tasks = tasks_client.list_tasks(TaskFieldFilter.SESSION_ID == session_id)
print(f"Found {total_tasks} tasks in total for the session {session_id}")

for t in tasks:
print(t)

except KeyboardInterrupt:
# If we stop the script, cancel the session
client.cancel_session(session_id)
Expand Down
3 changes: 2 additions & 1 deletion packages/python/proto2python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ mkdir -p $ARMONIK_WORKER $ARMONIK_CLIENT $ARMONIK_COMMON $PACKAGE_PATH
python -m pip install --upgrade pip
python -m venv $PYTHON_VENV
source $PYTHON_VENV/bin/activate
python -m pip install build grpcio grpcio-tools click pytest setuptools_scm[toml]
# We need to fix grpc to 1.56 until this bug is solved : https://github.com/grpc/grpc/issues/34305
python -m pip install build grpcio==1.56.2 grpcio-tools==1.56.2 click pytest setuptools_scm[toml]

unset proto_files
for proto in ${armonik_worker_files[@]}; do
Expand Down
6 changes: 3 additions & 3 deletions packages/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ classifiers = [
"Programming Language :: Python :: 3",
]
dependencies = [
"grpcio",
"grpcio-tools"
"grpcio==1.56.2",
"grpcio-tools==1.56.2"
]

# We need to set grpc to 1.56 until this bug is resolved : https://github.com/grpc/grpc/issues/34305
[project.urls]
"Homepage" = "https://github.com/aneoconsulting/ArmoniK.Api"
"Bug Tracker" = "https://github.com/aneoconsulting/ArmoniK/issues"
Expand Down
5 changes: 3 additions & 2 deletions packages/python/src/armonik/client/results.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
from grpc import Channel

from typing import List, Dict, cast
Expand All @@ -15,5 +16,5 @@ def __init__(self, grpc_channel: Channel):
"""
self._client = ResultsStub(grpc_channel)

def get_results_ids(self, session_id: str, names : List[str]) -> Dict[str, str]:
return {r.name : r.result_id for r in cast(CreateResultsMetaDataResponse, self._client.CreateResultsMetaData(CreateResultsMetaDataRequest(results=[CreateResultsMetaDataRequest.ResultCreate(name = n) for n in names], session_id=session_id))).results}
def get_results_ids(self, session_id: str, names: List[str]) -> Dict[str, str]:
return {r.name : r.result_id for r in cast(CreateResultsMetaDataResponse, self._client.CreateResultsMetaData(CreateResultsMetaDataRequest(results=[CreateResultsMetaDataRequest.ResultCreate(name = n) for n in names], session_id=session_id))).results}
13 changes: 7 additions & 6 deletions packages/python/src/armonik/client/submitter.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import uuid
from __future__ import annotations
from typing import Optional, List, Tuple, Dict, Union, Generator

from grpc import Channel

from ..common import get_task_filter, TaskOptions, TaskDefinition, Task, TaskStatus, ResultAvailability
from ..common import get_task_filter, TaskOptions, TaskDefinition, Task, ResultAvailability
from ..protogen.client.submitter_service_pb2_grpc import SubmitterStub
from ..protogen.common.objects_pb2 import Empty, TaskRequest, ResultRequest, DataChunk, InitTaskRequest, \
TaskRequestHeader, Configuration, Session, TaskOptions as InnerTaskOptions
from ..protogen.common.submitter_common_pb2 import CreateSessionRequest, GetTaskStatusRequest, CreateLargeTaskRequest, \
WaitRequest
WaitRequest, GetTaskStatusReply
from ..protogen.common.task_status_pb2 import TaskStatus


class ArmoniKSubmitter:
Expand Down Expand Up @@ -135,8 +136,8 @@ def get_task_status(self, task_ids: List[str]) -> Dict[str, TaskStatus]:
"""
request = GetTaskStatusRequest()
request.task_ids.extend(task_ids)
reply = self._client.GetTaskStatus(request)
return {s.task_id: TaskStatus(s.status) for s in reply.id_statuses}
reply: GetTaskStatusReply = self._client.GetTaskStatus(request)
return {s.task_id: s.status for s in reply.id_statuses}

def wait_for_completion(self,
session_ids: Optional[List[str]] = None,
Expand Down Expand Up @@ -167,7 +168,7 @@ def wait_for_completion(self,
Dictionary containing the number of tasks in each status
after waiting for completion
"""
return {TaskStatus(sc.status): sc.count for sc in self._client.WaitForCompletion(
return {sc.status: sc.count for sc in self._client.WaitForCompletion(
WaitRequest(filter=get_task_filter(session_ids, task_ids, included_statuses, excluded_statuses),
stop_on_first_task_error=stop_on_first_task_error,
stop_on_first_task_cancellation=stop_on_first_task_cancellation)).values}
Expand Down
78 changes: 76 additions & 2 deletions packages/python/src/armonik/client/tasks.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,56 @@
from __future__ import annotations
from grpc import Channel
from typing import cast, Tuple, List

from ..common import Task
from ..common import Task, Direction
from ..common.filter import StringFilter, StatusFilter, DateFilter, NumberFilter, Filter
from ..protogen.client.tasks_service_pb2_grpc import TasksStub
from ..protogen.common.tasks_common_pb2 import GetTaskRequest, GetTaskResponse
from ..protogen.common.tasks_common_pb2 import GetTaskRequest, GetTaskResponse, ListTasksRequest, ListTasksDetailedResponse
from ..protogen.common.tasks_filters_pb2 import Filters as rawFilters, FiltersAnd as rawFilterAnd, FilterField as rawFilterField, FilterStatus as rawFilterStatus
from ..protogen.common.sort_direction_pb2 import SortDirection

from ..protogen.common.tasks_fields_pb2 import *


class TaskFieldFilter:
"""
Enumeration of the available filters
"""
TASK_ID = StringFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_TASK_ID)), rawFilters, rawFilterAnd, rawFilterField)
SESSION_ID = StringFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_SESSION_ID)), rawFilters, rawFilterAnd, rawFilterField)
OWNER_POD_ID = StringFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_OWNER_POD_ID)), rawFilters, rawFilterAnd, rawFilterField)
INITIAL_TASK_ID = StringFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_INITIAL_TASK_ID)), rawFilters, rawFilterAnd, rawFilterField)
STATUS = StatusFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_STATUS)), rawFilters, rawFilterAnd, rawFilterField, rawFilterStatus)
CREATED_AT = DateFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_CREATED_AT)), rawFilters, rawFilterAnd, rawFilterField)
SUBMITTED_AT = DateFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_SUBMITTED_AT)), rawFilters, rawFilterAnd, rawFilterField)
STARTED_AT = DateFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_STARTED_AT)), rawFilters, rawFilterAnd, rawFilterField)
ENDED_AT = DateFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_ENDED_AT)), rawFilters, rawFilterAnd, rawFilterField)
POD_TTL = DateFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_POD_TTL)), rawFilters, rawFilterAnd, rawFilterField)
POD_HOSTNAME = StringFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_POD_HOSTNAME)), rawFilters, rawFilterAnd, rawFilterField)
RECEIVED_AT = DateFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_RECEIVED_AT)), rawFilters, rawFilterAnd, rawFilterField)
ACQUIRED_AT = DateFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_ACQUIRED_AT)), rawFilters, rawFilterAnd, rawFilterField)
ERROR = StringFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_ERROR)), rawFilters, rawFilterAnd, rawFilterField)

MAX_RETRIES = NumberFilter(TaskField(task_option_field=TaskOptionField(field=TASK_OPTION_ENUM_FIELD_MAX_RETRIES)), rawFilters, rawFilterAnd, rawFilterField)
PRIORITY = NumberFilter(TaskField(task_option_field=TaskOptionField(field=TASK_OPTION_ENUM_FIELD_PRIORITY)), rawFilters, rawFilterAnd, rawFilterField)
PARTITION_ID = StringFilter(TaskField(task_option_field=TaskOptionField(field=TASK_OPTION_ENUM_FIELD_PARTITION_ID)), rawFilters, rawFilterAnd, rawFilterField)
APPLICATION_NAME = StringFilter(TaskField(task_option_field=TaskOptionField(field=TASK_OPTION_ENUM_FIELD_APPLICATION_NAME)), rawFilters, rawFilterAnd, rawFilterField)
APPLICATION_VERSION = StringFilter(TaskField(task_option_field=TaskOptionField(field=TASK_OPTION_ENUM_FIELD_APPLICATION_VERSION)), rawFilters, rawFilterAnd, rawFilterField)
APPLICATION_NAMESPACE = StringFilter(TaskField(task_option_field=TaskOptionField(field=TASK_OPTION_ENUM_FIELD_APPLICATION_NAMESPACE)), rawFilters, rawFilterAnd, rawFilterField)
APPLICATION_SERVICE = StringFilter(TaskField(task_option_field=TaskOptionField(field=TASK_OPTION_ENUM_FIELD_APPLICATION_SERVICE)), rawFilters, rawFilterAnd, rawFilterField)
ENGINE_TYPE = StringFilter(TaskField(task_option_field=TaskOptionField(field=TASK_OPTION_ENUM_FIELD_ENGINE_TYPE)), rawFilters, rawFilterAnd, rawFilterField)

@staticmethod
def task_options_key(option_key: str) -> StringFilter:
"""
Filter for the TaskOptions.Options dictionary
Args:
option_key: key in the dictionary
Returns:
Corresponding filter
"""
return StringFilter(TaskField(task_option_generic_field=TaskOptionGenericField(field=option_key)), rawFilters, rawFilterAnd, rawFilterField)


class ArmoniKTasks:
Expand All @@ -25,3 +73,29 @@ def get_task(self, task_id: str) -> Task:
"""
task_response: GetTaskResponse = self._client.GetTask(GetTaskRequest(task_id=task_id))
return Task.from_message(task_response.task)

def list_tasks(self, task_filter: Filter, with_errors: bool = False, page: int = 0, page_size: int = 1000, sort_field: Filter = TaskFieldFilter.TASK_ID, sort_direction: SortDirection = Direction.ASC) -> Tuple[int, List[Task]]:
"""List tasks
If the total returned exceeds the requested page size, you may want to use this function again and ask for subsequent pages.
Args:
task_filter: Filter for the tasks to be listed
with_errors: Retrieve the error if the task had errors, defaults to false
page: page number to request, this can be useful when paginating the result, defaults to 0
page_size: size of a page, defaults to 1000
sort_field: field on which to sort the resulting list, defaults to the task_id
sort_direction: direction of the sort, defaults to ascending
Returns:
A tuple containing :
- The total number of tasks for the given filter
- The obtained list of tasks
"""
request = ListTasksRequest(page=page,
page_size=page_size,
filters=cast(rawFilters, task_filter.to_disjunction().to_message()),
sort=ListTasksRequest.Sort(field=cast(TaskField, sort_field.field), direction=sort_direction),
with_errors=with_errors)
list_response: ListTasksDetailedResponse = self._client.ListTasksDetailed(request)
return list_response.total, [Task.from_message(t) for t in list_response.tasks]
3 changes: 2 additions & 1 deletion packages/python/src/armonik/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .helpers import datetime_to_timestamp, timestamp_to_datetime, duration_to_timedelta, timedelta_to_duration, get_task_filter
from .objects import Task, TaskDefinition, TaskOptions, Output, ResultAvailability
from .enumwrapper import HealthCheckStatus, TaskStatus
from .enumwrapper import HealthCheckStatus, TaskStatus, Direction
from .filter import StringFilter, StatusFilter
18 changes: 14 additions & 4 deletions packages/python/src/armonik/common/enumwrapper.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
import enum
from __future__ import annotations

from ..protogen.common.task_status_pb2 import *
from ..protogen.common.task_status_pb2 import TaskStatus as RawStatus, _TASKSTATUS, TASK_STATUS_CANCELLED, TASK_STATUS_CANCELLING, TASK_STATUS_COMPLETED, TASK_STATUS_CREATING, TASK_STATUS_DISPATCHED, TASK_STATUS_ERROR, TASK_STATUS_PROCESSED, TASK_STATUS_PROCESSING, TASK_STATUS_SUBMITTED, TASK_STATUS_TIMEOUT, TASK_STATUS_UNSPECIFIED
from ..protogen.common.worker_common_pb2 import HealthCheckReply
from ..protogen.common.sort_direction_pb2 import SORT_DIRECTION_ASC, SORT_DIRECTION_DESC

# This file is necessary because the grpc types aren't considered proper types


class HealthCheckStatus(enum.Enum):
class HealthCheckStatus:
NOT_SERVING = HealthCheckReply.NOT_SERVING
SERVING = HealthCheckReply.SERVING
UNKNOWN = HealthCheckReply.UNKNOWN


class TaskStatus(enum.Enum):
class TaskStatus:
@staticmethod
def name_from_value(status: RawStatus) -> str:
return _TASKSTATUS.values_by_number[status].name

CANCELLED = TASK_STATUS_CANCELLED
CANCELLING = TASK_STATUS_CANCELLING
COMPLETED = TASK_STATUS_COMPLETED
Expand All @@ -24,3 +29,8 @@ class TaskStatus(enum.Enum):
SUBMITTED = TASK_STATUS_SUBMITTED
TIMEOUT = TASK_STATUS_TIMEOUT
UNSPECIFIED = TASK_STATUS_UNSPECIFIED


class Direction:
ASC = SORT_DIRECTION_ASC
DESC = SORT_DIRECTION_DESC
Loading

0 comments on commit ceeb413

Please sign in to comment.