diff --git a/examples/python/client.py b/examples/python/client.py index a996de3b5..399fc0018 100644 --- a/examples/python/client.py +++ b/examples/python/client.py @@ -2,7 +2,8 @@ import grpc import argparse from typing import cast -from armonik.client import ArmoniKSubmitter, ArmoniKResult +from armonik.client import ArmoniKSubmitter, ArmoniKResult, ArmoniKTasks +from armonik.client.tasks import TaskFieldFilter from armonik.common import TaskDefinition, TaskOptions from datetime import timedelta, datetime from common import Payload, Result @@ -14,6 +15,7 @@ def parse_arguments(): parser.add_argument("-p", "--partition", type=str, help="Partition used for the worker") parser.add_argument("-v", "--values", type=float, help="List of values to compute instead of x in [0, n[", nargs='+') parser.add_argument("-n", "--nfirst", type=int, help="Compute from 0 inclusive to n exclusive, n=10 by default", default=10) + parser.add_argument("-l", "--list", action="store_true", help="List tasks of the session at the end") return parser.parse_args() @@ -61,6 +63,20 @@ def main(): # Result is in error errors = "\n".join(reply.errors) print(f'Errors : {errors}') + + # List tasks + if args.list: + print(f"Listing tasks of session {session_id}") + # Create the tasks client + tasks_client = ArmoniKTasks(channel) + + # Request listing of tasks from the session + total_tasks, tasks = tasks_client.list_tasks(TaskFieldFilter.SESSION_ID == session_id) + print(f"Found {total_tasks} tasks in total for the session {session_id}") + + for t in tasks: + print(t) + except KeyboardInterrupt: # If we stop the script, cancel the session client.cancel_session(session_id) diff --git a/packages/python/proto2python.sh b/packages/python/proto2python.sh index 786a047c1..c2fe45244 100644 --- a/packages/python/proto2python.sh +++ b/packages/python/proto2python.sh @@ -32,7 +32,8 @@ mkdir -p $ARMONIK_WORKER $ARMONIK_CLIENT $ARMONIK_COMMON $PACKAGE_PATH python -m pip install --upgrade pip python -m venv $PYTHON_VENV source $PYTHON_VENV/bin/activate -python -m pip install build grpcio grpcio-tools click pytest setuptools_scm[toml] +# We need to fix grpc to 1.56 until this bug is solved : https://github.com/grpc/grpc/issues/34305 +python -m pip install build grpcio==1.56.2 grpcio-tools==1.56.2 click pytest setuptools_scm[toml] unset proto_files for proto in ${armonik_worker_files[@]}; do diff --git a/packages/python/pyproject.toml b/packages/python/pyproject.toml index 98dfbf03a..0b66674bc 100644 --- a/packages/python/pyproject.toml +++ b/packages/python/pyproject.toml @@ -15,10 +15,10 @@ classifiers = [ "Programming Language :: Python :: 3", ] dependencies = [ - "grpcio", - "grpcio-tools" + "grpcio==1.56.2", + "grpcio-tools==1.56.2" ] - +# We need to set grpc to 1.56 until this bug is resolved : https://github.com/grpc/grpc/issues/34305 [project.urls] "Homepage" = "https://github.com/aneoconsulting/ArmoniK.Api" "Bug Tracker" = "https://github.com/aneoconsulting/ArmoniK/issues" diff --git a/packages/python/src/armonik/client/results.py b/packages/python/src/armonik/client/results.py index b60bc4967..642fc346f 100644 --- a/packages/python/src/armonik/client/results.py +++ b/packages/python/src/armonik/client/results.py @@ -1,3 +1,4 @@ +from __future__ import annotations from grpc import Channel from typing import List, Dict, cast @@ -15,5 +16,5 @@ def __init__(self, grpc_channel: Channel): """ self._client = ResultsStub(grpc_channel) - def get_results_ids(self, session_id: str, names : List[str]) -> Dict[str, str]: - return {r.name : r.result_id for r in cast(CreateResultsMetaDataResponse, self._client.CreateResultsMetaData(CreateResultsMetaDataRequest(results=[CreateResultsMetaDataRequest.ResultCreate(name = n) for n in names], session_id=session_id))).results} \ No newline at end of file + def get_results_ids(self, session_id: str, names: List[str]) -> Dict[str, str]: + return {r.name : r.result_id for r in cast(CreateResultsMetaDataResponse, self._client.CreateResultsMetaData(CreateResultsMetaDataRequest(results=[CreateResultsMetaDataRequest.ResultCreate(name = n) for n in names], session_id=session_id))).results} diff --git a/packages/python/src/armonik/client/submitter.py b/packages/python/src/armonik/client/submitter.py index 0623555c9..3b9e402d2 100644 --- a/packages/python/src/armonik/client/submitter.py +++ b/packages/python/src/armonik/client/submitter.py @@ -1,14 +1,15 @@ -import uuid +from __future__ import annotations from typing import Optional, List, Tuple, Dict, Union, Generator from grpc import Channel -from ..common import get_task_filter, TaskOptions, TaskDefinition, Task, TaskStatus, ResultAvailability +from ..common import get_task_filter, TaskOptions, TaskDefinition, Task, ResultAvailability from ..protogen.client.submitter_service_pb2_grpc import SubmitterStub from ..protogen.common.objects_pb2 import Empty, TaskRequest, ResultRequest, DataChunk, InitTaskRequest, \ TaskRequestHeader, Configuration, Session, TaskOptions as InnerTaskOptions from ..protogen.common.submitter_common_pb2 import CreateSessionRequest, GetTaskStatusRequest, CreateLargeTaskRequest, \ - WaitRequest + WaitRequest, GetTaskStatusReply +from ..protogen.common.task_status_pb2 import TaskStatus class ArmoniKSubmitter: @@ -135,8 +136,8 @@ def get_task_status(self, task_ids: List[str]) -> Dict[str, TaskStatus]: """ request = GetTaskStatusRequest() request.task_ids.extend(task_ids) - reply = self._client.GetTaskStatus(request) - return {s.task_id: TaskStatus(s.status) for s in reply.id_statuses} + reply: GetTaskStatusReply = self._client.GetTaskStatus(request) + return {s.task_id: s.status for s in reply.id_statuses} def wait_for_completion(self, session_ids: Optional[List[str]] = None, @@ -167,7 +168,7 @@ def wait_for_completion(self, Dictionary containing the number of tasks in each status after waiting for completion """ - return {TaskStatus(sc.status): sc.count for sc in self._client.WaitForCompletion( + return {sc.status: sc.count for sc in self._client.WaitForCompletion( WaitRequest(filter=get_task_filter(session_ids, task_ids, included_statuses, excluded_statuses), stop_on_first_task_error=stop_on_first_task_error, stop_on_first_task_cancellation=stop_on_first_task_cancellation)).values} diff --git a/packages/python/src/armonik/client/tasks.py b/packages/python/src/armonik/client/tasks.py index 1696cd06c..d6fb6876a 100644 --- a/packages/python/src/armonik/client/tasks.py +++ b/packages/python/src/armonik/client/tasks.py @@ -1,8 +1,56 @@ +from __future__ import annotations from grpc import Channel +from typing import cast, Tuple, List -from ..common import Task +from ..common import Task, Direction +from ..common.filter import StringFilter, StatusFilter, DateFilter, NumberFilter, Filter from ..protogen.client.tasks_service_pb2_grpc import TasksStub -from ..protogen.common.tasks_common_pb2 import GetTaskRequest, GetTaskResponse +from ..protogen.common.tasks_common_pb2 import GetTaskRequest, GetTaskResponse, ListTasksRequest, ListTasksDetailedResponse +from ..protogen.common.tasks_filters_pb2 import Filters as rawFilters, FiltersAnd as rawFilterAnd, FilterField as rawFilterField, FilterStatus as rawFilterStatus +from ..protogen.common.sort_direction_pb2 import SortDirection + +from ..protogen.common.tasks_fields_pb2 import * + + +class TaskFieldFilter: + """ + Enumeration of the available filters + """ + TASK_ID = StringFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_TASK_ID)), rawFilters, rawFilterAnd, rawFilterField) + SESSION_ID = StringFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_SESSION_ID)), rawFilters, rawFilterAnd, rawFilterField) + OWNER_POD_ID = StringFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_OWNER_POD_ID)), rawFilters, rawFilterAnd, rawFilterField) + INITIAL_TASK_ID = StringFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_INITIAL_TASK_ID)), rawFilters, rawFilterAnd, rawFilterField) + STATUS = StatusFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_STATUS)), rawFilters, rawFilterAnd, rawFilterField, rawFilterStatus) + CREATED_AT = DateFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_CREATED_AT)), rawFilters, rawFilterAnd, rawFilterField) + SUBMITTED_AT = DateFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_SUBMITTED_AT)), rawFilters, rawFilterAnd, rawFilterField) + STARTED_AT = DateFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_STARTED_AT)), rawFilters, rawFilterAnd, rawFilterField) + ENDED_AT = DateFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_ENDED_AT)), rawFilters, rawFilterAnd, rawFilterField) + POD_TTL = DateFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_POD_TTL)), rawFilters, rawFilterAnd, rawFilterField) + POD_HOSTNAME = StringFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_POD_HOSTNAME)), rawFilters, rawFilterAnd, rawFilterField) + RECEIVED_AT = DateFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_RECEIVED_AT)), rawFilters, rawFilterAnd, rawFilterField) + ACQUIRED_AT = DateFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_ACQUIRED_AT)), rawFilters, rawFilterAnd, rawFilterField) + ERROR = StringFilter(TaskField(task_summary_field=TaskSummaryField(field=TASK_SUMMARY_ENUM_FIELD_ERROR)), rawFilters, rawFilterAnd, rawFilterField) + + MAX_RETRIES = NumberFilter(TaskField(task_option_field=TaskOptionField(field=TASK_OPTION_ENUM_FIELD_MAX_RETRIES)), rawFilters, rawFilterAnd, rawFilterField) + PRIORITY = NumberFilter(TaskField(task_option_field=TaskOptionField(field=TASK_OPTION_ENUM_FIELD_PRIORITY)), rawFilters, rawFilterAnd, rawFilterField) + PARTITION_ID = StringFilter(TaskField(task_option_field=TaskOptionField(field=TASK_OPTION_ENUM_FIELD_PARTITION_ID)), rawFilters, rawFilterAnd, rawFilterField) + APPLICATION_NAME = StringFilter(TaskField(task_option_field=TaskOptionField(field=TASK_OPTION_ENUM_FIELD_APPLICATION_NAME)), rawFilters, rawFilterAnd, rawFilterField) + APPLICATION_VERSION = StringFilter(TaskField(task_option_field=TaskOptionField(field=TASK_OPTION_ENUM_FIELD_APPLICATION_VERSION)), rawFilters, rawFilterAnd, rawFilterField) + APPLICATION_NAMESPACE = StringFilter(TaskField(task_option_field=TaskOptionField(field=TASK_OPTION_ENUM_FIELD_APPLICATION_NAMESPACE)), rawFilters, rawFilterAnd, rawFilterField) + APPLICATION_SERVICE = StringFilter(TaskField(task_option_field=TaskOptionField(field=TASK_OPTION_ENUM_FIELD_APPLICATION_SERVICE)), rawFilters, rawFilterAnd, rawFilterField) + ENGINE_TYPE = StringFilter(TaskField(task_option_field=TaskOptionField(field=TASK_OPTION_ENUM_FIELD_ENGINE_TYPE)), rawFilters, rawFilterAnd, rawFilterField) + + @staticmethod + def task_options_key(option_key: str) -> StringFilter: + """ + Filter for the TaskOptions.Options dictionary + Args: + option_key: key in the dictionary + + Returns: + Corresponding filter + """ + return StringFilter(TaskField(task_option_generic_field=TaskOptionGenericField(field=option_key)), rawFilters, rawFilterAnd, rawFilterField) class ArmoniKTasks: @@ -25,3 +73,29 @@ def get_task(self, task_id: str) -> Task: """ task_response: GetTaskResponse = self._client.GetTask(GetTaskRequest(task_id=task_id)) return Task.from_message(task_response.task) + + def list_tasks(self, task_filter: Filter, with_errors: bool = False, page: int = 0, page_size: int = 1000, sort_field: Filter = TaskFieldFilter.TASK_ID, sort_direction: SortDirection = Direction.ASC) -> Tuple[int, List[Task]]: + """List tasks + + If the total returned exceeds the requested page size, you may want to use this function again and ask for subsequent pages. + + Args: + task_filter: Filter for the tasks to be listed + with_errors: Retrieve the error if the task had errors, defaults to false + page: page number to request, this can be useful when paginating the result, defaults to 0 + page_size: size of a page, defaults to 1000 + sort_field: field on which to sort the resulting list, defaults to the task_id + sort_direction: direction of the sort, defaults to ascending + + Returns: + A tuple containing : + - The total number of tasks for the given filter + - The obtained list of tasks + """ + request = ListTasksRequest(page=page, + page_size=page_size, + filters=cast(rawFilters, task_filter.to_disjunction().to_message()), + sort=ListTasksRequest.Sort(field=cast(TaskField, sort_field.field), direction=sort_direction), + with_errors=with_errors) + list_response: ListTasksDetailedResponse = self._client.ListTasksDetailed(request) + return list_response.total, [Task.from_message(t) for t in list_response.tasks] diff --git a/packages/python/src/armonik/common/__init__.py b/packages/python/src/armonik/common/__init__.py index 601c66825..50859d6f3 100644 --- a/packages/python/src/armonik/common/__init__.py +++ b/packages/python/src/armonik/common/__init__.py @@ -1,3 +1,4 @@ from .helpers import datetime_to_timestamp, timestamp_to_datetime, duration_to_timedelta, timedelta_to_duration, get_task_filter from .objects import Task, TaskDefinition, TaskOptions, Output, ResultAvailability -from .enumwrapper import HealthCheckStatus, TaskStatus +from .enumwrapper import HealthCheckStatus, TaskStatus, Direction +from .filter import StringFilter, StatusFilter diff --git a/packages/python/src/armonik/common/enumwrapper.py b/packages/python/src/armonik/common/enumwrapper.py index 3c646a796..433a52e8a 100644 --- a/packages/python/src/armonik/common/enumwrapper.py +++ b/packages/python/src/armonik/common/enumwrapper.py @@ -1,18 +1,23 @@ -import enum +from __future__ import annotations -from ..protogen.common.task_status_pb2 import * +from ..protogen.common.task_status_pb2 import TaskStatus as RawStatus, _TASKSTATUS, TASK_STATUS_CANCELLED, TASK_STATUS_CANCELLING, TASK_STATUS_COMPLETED, TASK_STATUS_CREATING, TASK_STATUS_DISPATCHED, TASK_STATUS_ERROR, TASK_STATUS_PROCESSED, TASK_STATUS_PROCESSING, TASK_STATUS_SUBMITTED, TASK_STATUS_TIMEOUT, TASK_STATUS_UNSPECIFIED from ..protogen.common.worker_common_pb2 import HealthCheckReply +from ..protogen.common.sort_direction_pb2 import SORT_DIRECTION_ASC, SORT_DIRECTION_DESC # This file is necessary because the grpc types aren't considered proper types -class HealthCheckStatus(enum.Enum): +class HealthCheckStatus: NOT_SERVING = HealthCheckReply.NOT_SERVING SERVING = HealthCheckReply.SERVING UNKNOWN = HealthCheckReply.UNKNOWN -class TaskStatus(enum.Enum): +class TaskStatus: + @staticmethod + def name_from_value(status: RawStatus) -> str: + return _TASKSTATUS.values_by_number[status].name + CANCELLED = TASK_STATUS_CANCELLED CANCELLING = TASK_STATUS_CANCELLING COMPLETED = TASK_STATUS_COMPLETED @@ -24,3 +29,8 @@ class TaskStatus(enum.Enum): SUBMITTED = TASK_STATUS_SUBMITTED TIMEOUT = TASK_STATUS_TIMEOUT UNSPECIFIED = TASK_STATUS_UNSPECIFIED + + +class Direction: + ASC = SORT_DIRECTION_ASC + DESC = SORT_DIRECTION_DESC diff --git a/packages/python/src/armonik/common/filter.py b/packages/python/src/armonik/common/filter.py new file mode 100644 index 000000000..d80ccd312 --- /dev/null +++ b/packages/python/src/armonik/common/filter.py @@ -0,0 +1,333 @@ +from __future__ import annotations +from abc import abstractmethod +from typing import List, Any, Type, Optional, Dict +from google.protobuf.message import Message +import google.protobuf.timestamp_pb2 as timestamp +from ..protogen.common.filters_common_pb2 import * +import json + + +class Filter: + """ + Filter for use with ArmoniK + + Attributes: + eq_: equality raw Api operator + ne_: inequality raw Api operator + lt_: less than raw Api operator + le_: less or equal raw Api operator + gt_: greater than raw Api operator + ge_: greater or equal raw Api operator + contains_: contains raw Api operator + notcontains_: not contains raw Api operator + value_type_: expected type for the value to test against in this filter + + field: field of the filter if it's a simple filter + message_type: Api message type of the filter + inner_message_type: Api message type of the inner filter (with value and operator) + conjunction_type: Type of the conjunction for this filter + value: value to test against in this filter if it's a simple filter + operator: operator to apply for this filter if it's a simple filter + """ + eq_ = None + ne_ = None + lt_ = None + le_ = None + gt_ = None + ge_ = None + contains_ = None + notcontains_ = None + value_type_ = None + + def __init__(self, field: Optional[Message], disjunction_message_type: Type[Message], conjunction_message_type: Type[Message], message_type: Type[Message], inner_message_type: Optional[Type[Message]], filters: Optional[List[List["Filter"]]] = None, value=None, operator=None): + self._filters: List[List["Filter"]] = [[]] if filters is None else filters + self.field = field + self.message_type = message_type + self.conjunction_type = conjunction_message_type + self.disjunction_type = disjunction_message_type + self.inner_message_type = inner_message_type + self.value = value + self.operator = operator + + def is_true_conjunction(self) -> bool: + """ + Tests whether the filter is a conjunction (logical and) + Note : This will only output true if it's an actual conjunction with multiple filters and no disjunction + """ + return self.message_type == self.conjunction_type or (len(self._filters) == 1 and len(self._filters[0]) > 1) + + def is_true_disjunction(self) -> bool: + """ + Tests whether the filter is a disjunction (logical or) + Note : This will only output true if it's an actual disjunction with multiple filters + """ + return len(self._filters) > 1 + + def to_disjunction(self) -> Filter: + """ + Converts the filter into a disjunction + + """ + if self.is_true_disjunction(): + return self + if self.is_true_conjunction(): + return Filter(None, self.disjunction_type, self.conjunction_type, self.disjunction_type, None, self._filters) + return Filter(None, self.disjunction_type, self.conjunction_type, self.disjunction_type, None, [[self]]) + + def __and__(self, other: "Filter") -> "Filter": + if not isinstance(other, Filter): + msg = f"Cannot create a conjunction between Filter and {other.__class__.__name__}" + raise Exception(msg) + if self.is_true_disjunction() or other.is_true_disjunction(): + raise Exception("Cannot make a conjunction of disjunctions") + if self.conjunction_type != other.conjunction_type: + raise Exception("Conjunction types are different") + return Filter(None, self.disjunction_type, self.conjunction_type, self.conjunction_type, None, [self.to_disjunction()._filters[0] + other.to_disjunction()._filters[0]]) + + def __mul__(self, other: Filter) -> "Filter": + return self & other + + def __or__(self, other: "Filter") -> "Filter": + if not isinstance(other, Filter): + msg = f"Cannot create a conjunction between Filter and {other.__class__.__name__}" + raise Exception(msg) + if self.disjunction_type != other.disjunction_type: + raise Exception("Disjunction types are different") + return Filter(None, self.disjunction_type, self.conjunction_type, self.disjunction_type, None, self.to_disjunction()._filters + other.to_disjunction()._filters) + + def __add__(self, other: "Filter") -> "Filter": + return self | other + + def __eq__(self, value) -> Filter: + return self._check(value, self.__class__.eq_, "==") + + def __ne__(self, value) -> Filter: + return self._check(value, self.__class__.ne_, "!=") + + def __lt__(self, value) -> Filter: + return self._check(value, self.__class__.lt_, "<") + + def __le__(self, value) -> Filter: + return self._check(value, self.__class__.le_, "<=") + + def __gt__(self, value) -> Filter: + return self._check(value, self.__class__.gt_, ">") + + def __ge__(self, value) -> Filter: + return self._check(value, self.__class__.ge_, ">=") + + def contains(self, value) -> Filter: + return self._check(value, self.__class__.contains_, "contains") + + def __invert__(self) -> Filter: + """ + Inverts the test + + Returns: + Filter with the test being inverted + """ + if self.operator is None: + if self.is_true_conjunction() or self.is_true_disjunction(): + raise Exception("Cannot invert conjunctions or disjunctions") + msg = f"Cannot invert None operator in class {self.__class__.__name__} for field {str(self.field)}" + raise Exception(msg) + if self.operator == self.__class__.eq_: + return self.__ne__(self.value) + if self.operator == self.__class__.ne_: + return self.__eq__(self.value) + if self.operator == self.__class__.lt_: + return self.__ge__(self.value) + if self.operator == self.__class__.le_: + return self.__gt__(self.value) + if self.operator == self.__class__.gt_: + return self.__le__(self.value) + if self.operator == self.__class__.ge_: + return self.__lt__(self.value) + if self.operator == self.__class__.contains_: + return self._check(self.value, self.__class__.notcontains_, "not_contains") + if self.operator == self.__class__.notcontains_: + return self.contains(self.value) + msg = f"{self.__class__.__name__} operator {str(self.operator)} for field {str(self.field)} has no inverted equivalent" + raise Exception(msg) + + def __neg__(self) -> "Filter": + return ~self + + def to_dict(self) -> Dict: + rep = {} + if self.is_true_disjunction(): + rep["or"] = [{"and": [f.to_dict() for f in conj]} for conj in self._filters] + return rep + if self.is_true_conjunction(): + rep["and"] = [f.to_dict() for f in self._filters[0]] + return rep + if len(self._filters) == 1 and len(self._filters[0]) == 1: + return self._filters[0][0].to_dict() + return {"field": str(self.field), "value": str(self.value), "operator": str(self.operator)} + + def __str__(self) -> str: + return json.dumps(self.to_dict()) + + def _verify_value(self, value): + """ + Checks if the value is of the expected type + Args: + value: Value to test + + Raises: + Exception if value is not of the expected type + + """ + if self.__class__.value_type_ is None or isinstance(value, self.__class__.value_type_): + return + msg = f"Expected value type {str(self.__class__.value_type_)} for field {str(self.field)}, got {str(type(value))} instead" + raise Exception(msg) + + def _check(self, value: Any, operator: Any, operator_str: str = "") -> "Filter": + """ + Internal function to create a new filter from the current filter with a different value and/or operator + Args: + value: Value of the new filter + operator: Operator of the new filter + operator_str: Optional string for error message clarification + + Returns: + new filter with the given value/operator + + Raises: + NotImplementedError if the given operator is not available for the given class + """ + if self.is_true_conjunction() or self.is_true_disjunction(): + raise Exception("Cannot apply operator to a disjunction or a conjunction") + self._verify_value(value) + if operator is None: + msg = f"Operator {operator_str} is not available for {self.__class__.__name__}" + raise NotImplementedError(msg) + return self.__class__(self.field, self.disjunction_type, self.conjunction_type, self.message_type, self.inner_message_type, self._filters, value, operator) + + @abstractmethod + def to_basic_message(self) -> Message: + pass + + def to_message(self) -> Message: + def to_conjunction_message(conj: List[Filter]) -> Message: + conj_raw = self.conjunction_type() + getattr(conj_raw, "and").extend([f.to_basic_message() for f in conj]) + return conj_raw + + if self.message_type == self.disjunction_type: + raw = self.to_disjunction().disjunction_type() + getattr(raw, "or").extend([to_conjunction_message(conj) for conj in self._filters]) + return raw + if self.message_type == self.conjunction_type: + return to_conjunction_message(self.to_disjunction()._filters[0]) + return self.to_basic_message() + + +class StringFilter(Filter): + """ + Filter for string comparisons + """ + eq_ = FILTER_STRING_OPERATOR_EQUAL + ne_ = FILTER_STRING_OPERATOR_NOT_EQUAL + contains_ = FILTER_STRING_OPERATOR_CONTAINS + notcontains_ = FILTER_STRING_OPERATOR_NOT_CONTAINS + value_type_ = str + + def __init__(self, field: Optional[Message], disjunction_message_type: Type[Message], conjunction_message_type: Type[Message], message_type: Type[Message], inner_message_type: Optional[Type[Message]] = FilterString, filters: Optional[List[List["Filter"]]] = None, value=None, operator=None): + super().__init__(field, disjunction_message_type, conjunction_message_type, message_type, inner_message_type, filters, value, operator) + + def startswith(self, value: str) -> "StringFilter": + return self._check(value, FILTER_STRING_OPERATOR_STARTS_WITH, "startswith") + + def endswith(self, value: str) -> "StringFilter": + return self._check(value, FILTER_STRING_OPERATOR_ENDS_WITH, "endswith") + + def to_basic_message(self) -> Message: + return self.message_type(field=self.field, filter_string=self.inner_message_type(value=self.value, operator=self.operator)) + + def __repr__(self) -> str: + return f"{str(self.field)} {str(self.operator)} \"{str(self.value)}\"" + + +class StatusFilter(Filter): + """ + Filter for status comparison + """ + eq_ = FILTER_STATUS_OPERATOR_EQUAL + ne_ = FILTER_STATUS_OPERATOR_NOT_EQUAL + + def __init__(self, field: Optional[Message], disjunction_message_type: Type[Message], conjunction_message_type: Type[Message], message_type: Type[Message], inner_message_type: Type[Message], filters: Optional[List[List["Filter"]]] = None, value=None, operator=None): + super().__init__(field, disjunction_message_type, conjunction_message_type, message_type, inner_message_type, filters, value, operator) + + def to_basic_message(self) -> Message: + return self.message_type(field=self.field, filter_status=self.inner_message_type(value=self.value, operator=self.operator)) + + +class DateFilter(Filter): + """Filter for timestamp comparison""" + eq_ = FILTER_DATE_OPERATOR_EQUAL + ne_ = FILTER_DATE_OPERATOR_NOT_EQUAL + lt_ = FILTER_DATE_OPERATOR_BEFORE + le_ = FILTER_DATE_OPERATOR_BEFORE_OR_EQUAL + gt_ = FILTER_DATE_OPERATOR_AFTER + ge_ = FILTER_DATE_OPERATOR_AFTER_OR_EQUAL + value_type = timestamp.Timestamp + + def __init__(self, field: Optional[Message], disjunction_message_type: Type[Message], conjunction_message_type: Type[Message], message_type: Type[Message], inner_message_type: Optional[Type[Message]] = FilterDate, filters: Optional[List[List["Filter"]]] = None, value=None, operator=None): + super().__init__(field, disjunction_message_type, conjunction_message_type, message_type, inner_message_type, filters, value, operator) + + def to_basic_message(self) -> Message: + return self.message_type(field=self.field, filter_date=self.inner_message_type(value=self.value, operator=self.operator)) + + +class NumberFilter(Filter): + """Filter for int comparison""" + eq_ = FILTER_NUMBER_OPERATOR_EQUAL + ne_ = FILTER_NUMBER_OPERATOR_NOT_EQUAL + lt_ = FILTER_NUMBER_OPERATOR_LESS_THAN + le_ = FILTER_NUMBER_OPERATOR_LESS_THAN_OR_EQUAL + gt_ = FILTER_NUMBER_OPERATOR_GREATER_THAN + ge_ = FILTER_NUMBER_OPERATOR_GREATER_THAN_OR_EQUAL + value_type_ = int + + def __init__(self, field: Optional[Message], disjunction_message_type: Type[Message], conjunction_message_type: Type[Message], message_type: Type[Message], inner_message_type: Optional[Type[Message]] = FilterNumber, filters: Optional[List[List["Filter"]]] = None, value=None, operator=None): + super().__init__(field, disjunction_message_type, conjunction_message_type, message_type, inner_message_type, filters, value, operator) + + def to_basic_message(self) -> Message: + return self.message_type(field=self.field, filter_number=self.inner_message_type(value=self.value, operator=self.operator)) + + +class BooleanFilter(Filter): + """ + Filter for boolean comparison + """ + eq_ = FILTER_BOOLEAN_OPERATOR_IS + value_type_ = bool + + def __init__(self, field: Optional[Message], disjunction_message_type: Type[Message], conjunction_message_type: Type[Message], message_type: Type[Message], inner_message_type: Optional[Type[Message]] = FilterBoolean, filters: Optional[List[List["Filter"]]] = None, value=True, operator=FILTER_BOOLEAN_OPERATOR_IS): + super().__init__(field, disjunction_message_type, conjunction_message_type, message_type, inner_message_type, filters, value, operator) + + def __ne__(self, value: bool) -> "BooleanFilter": + return self.__eq__(not value) + + def __invert__(self) -> "BooleanFilter": + return self.__eq__(not self.value) + + def to_basic_message(self) -> Message: + return self.message_type(field=self.field, filter_boolean=self.inner_message_type(value=self.value, operator=self.operator)) + + +class ArrayFilter(Filter): + """ + Filter for array comparisons + """ + contains_ = FILTER_ARRAY_OPERATOR_CONTAINS + notcontains_ = FILTER_ARRAY_OPERATOR_NOT_CONTAINS + value_type_ = str + + def __init__(self, field: Optional[Message], disjunction_message_type: Type[Message], conjunction_message_type: Type[Message], message_type: Type[Message], inner_message_type: Optional[Type[Message]] = FilterArray, filters: Optional[List[List["Filter"]]] = None, value=None, operator=None): + super().__init__(field, disjunction_message_type, conjunction_message_type, message_type, inner_message_type, filters, value, operator) + + def to_basic_message(self) -> Message: + return self.message_type(field=self.field, filter_array=self.inner_message_type(value=self.value, operator=self.operator)) diff --git a/packages/python/src/armonik/common/helpers.py b/packages/python/src/armonik/common/helpers.py index cbc08bba1..e174e2f42 100644 --- a/packages/python/src/armonik/common/helpers.py +++ b/packages/python/src/armonik/common/helpers.py @@ -1,3 +1,4 @@ +from __future__ import annotations from datetime import timedelta, datetime, timezone from typing import List, Optional @@ -37,9 +38,9 @@ def get_task_filter(session_ids: Optional[List[str]] = None, task_ids: Optional[ if task_ids: task_filter.task.ids.extend(task_ids) if included_statuses: - task_filter.included.statuses.extend([t.value for t in included_statuses]) + task_filter.included.statuses.extend(included_statuses) if excluded_statuses: - task_filter.excluded.statuses.extend([t.value for t in excluded_statuses]) + task_filter.excluded.statuses.extend(excluded_statuses) return task_filter diff --git a/packages/python/src/armonik/common/objects.py b/packages/python/src/armonik/common/objects.py index 02d6f1f7b..453cb5c56 100644 --- a/packages/python/src/armonik/common/objects.py +++ b/packages/python/src/armonik/common/objects.py @@ -1,3 +1,4 @@ +from __future__ import annotations from dataclasses import dataclass, field from datetime import timedelta, datetime from typing import Optional, List, Dict @@ -5,6 +6,7 @@ from ..protogen.common.tasks_common_pb2 import TaskDetailed from .helpers import duration_to_timedelta, timedelta_to_duration, timestamp_to_datetime from ..protogen.common.objects_pb2 import Empty, Output as WorkerOutput, TaskOptions as RawTaskOptions +from ..protogen.common.task_status_pb2 import TaskStatus as RawTaskStatus from .enumwrapper import TaskStatus @@ -83,7 +85,7 @@ class Task: data_dependencies: List[str] = field(default_factory=list) expected_output_ids: List[str] = field(default_factory=list) retry_of_ids: List[str] = field(default_factory=list) - status: TaskStatus = TaskStatus.UNSPECIFIED + status: RawTaskStatus = TaskStatus.UNSPECIFIED status_message: Optional[str] = None options: Optional[TaskOptions] = None created_at: Optional[datetime] = None @@ -109,7 +111,7 @@ def refresh(self, task_client) -> None: self.data_dependencies = result.data_dependencies self.expected_output_ids = result.expected_output_ids self.retry_of_ids = result.retry_of_ids - self.status = TaskStatus(result.status) + self.status = result.status self.status_message = result.status_message self.options = result.options self.created_at = result.created_at @@ -133,7 +135,7 @@ def from_message(cls, task_raw: TaskDetailed) -> "Task": data_dependencies=list(task_raw.data_dependencies), expected_output_ids=list(task_raw.expected_output_ids), retry_of_ids=list(task_raw.retry_of_ids), - status=TaskStatus(task_raw.status), + status=task_raw.status, status_message=task_raw.status_message, options=TaskOptions.from_message(task_raw.options), created_at=timestamp_to_datetime(task_raw.created_at), diff --git a/packages/python/src/armonik/worker/seqlogger.py b/packages/python/src/armonik/worker/seqlogger.py index 79637bce1..4d5d48ec7 100644 --- a/packages/python/src/armonik/worker/seqlogger.py +++ b/packages/python/src/armonik/worker/seqlogger.py @@ -1,3 +1,4 @@ +from __future__ import annotations import logging import sys import json diff --git a/packages/python/src/armonik/worker/taskhandler.py b/packages/python/src/armonik/worker/taskhandler.py index ec70a27bd..9d8ade962 100644 --- a/packages/python/src/armonik/worker/taskhandler.py +++ b/packages/python/src/armonik/worker/taskhandler.py @@ -1,4 +1,4 @@ -import uuid +from __future__ import annotations from typing import Optional, Dict, List, Tuple, Union, cast from ..common import TaskOptions, TaskDefinition, Task @@ -141,7 +141,7 @@ def result_stream(): result_reply = self._client.SendResult(result_stream()) if result_reply.WhichOneof("type") == "error": raise Exception(f"Cannot send result id={key}") - + def get_results_ids(self, names : List[str]) -> Dict[str, str]: return {r.name : r.result_id for r in cast(CreateResultsMetaDataResponse, self._client.CreateResultsMetaData(CreateResultsMetaDataRequest(results=[CreateResultsMetaDataRequest.ResultCreate(name = n) for n in names], session_id=self.session_id, communication_token=self.token))).results} diff --git a/packages/python/src/armonik/worker/worker.py b/packages/python/src/armonik/worker/worker.py index 9ac1a6884..19db04f38 100644 --- a/packages/python/src/armonik/worker/worker.py +++ b/packages/python/src/armonik/worker/worker.py @@ -1,3 +1,4 @@ +from __future__ import annotations import traceback from concurrent import futures from typing import Callable, Union @@ -15,7 +16,7 @@ class ArmoniKWorker(WorkerServicer): - def __init__(self, agent_channel: Channel, processing_function: Callable[[TaskHandler], Output], health_check: Callable[[], HealthCheckStatus] = lambda: HealthCheckStatus.SERVING, logger=ClefLogger.getLogger("ArmoniKWorker")): + def __init__(self, agent_channel: Channel, processing_function: Callable[[TaskHandler], Output], health_check: Callable[[], HealthCheckReply.ServingStatus] = lambda: HealthCheckStatus.SERVING, logger=ClefLogger.getLogger("ArmoniKWorker")): """Creates a worker for ArmoniK Args: @@ -54,4 +55,4 @@ def Process(self, request_iterator, context) -> Union[ProcessReply, None]: self._logger.exception(f"Failed task {''.join(traceback.format_exception(type(e) ,e, e.__traceback__))}", exc_info=e) def HealthCheck(self, request: Empty, context) -> HealthCheckReply: - return HealthCheckReply(status=self.health_check().value) + return HealthCheckReply(status=self.health_check()) diff --git a/packages/python/tests/filters_test.py b/packages/python/tests/filters_test.py new file mode 100644 index 000000000..b67f7341f --- /dev/null +++ b/packages/python/tests/filters_test.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +from typing import Type + +import pytest + +from dataclasses import dataclass +from armonik.common.filter import Filter, StringFilter, BooleanFilter, NumberFilter +from armonik.protogen.common.filters_common_pb2 import FilterBoolean +from google.protobuf.message import Message + + +@dataclass +class DummyMessage(Message): + pass + + +@dataclass +class DummyMessageAnd(Message): + pass + + +@dataclass +class Field(Message): + pass + + +@pytest.mark.parametrize("filt,inverted", [ + (StringFilter(Field(), DummyMessage, DummyMessageAnd, DummyMessage) == "Test", StringFilter(Field(), DummyMessage, DummyMessageAnd, DummyMessage) != "Test"), + (StringFilter(Field(), DummyMessage, DummyMessageAnd, DummyMessage).contains("Test"), ~(StringFilter(Field(), DummyMessage, DummyMessageAnd, DummyMessage).contains("Test"))), + (BooleanFilter(Field(), DummyMessage, DummyMessageAnd, DummyMessage), BooleanFilter(Field(), DummyMessage, DummyMessageAnd, DummyMessage, FilterBoolean, None, False)), + (NumberFilter(Field(), DummyMessage, DummyMessageAnd, DummyMessage) > 0, NumberFilter(Field(), DummyMessage, DummyMessageAnd, DummyMessage) <= 0), + (NumberFilter(Field(), DummyMessage, DummyMessageAnd, DummyMessage) >= 0, NumberFilter(Field(), DummyMessage, DummyMessageAnd, DummyMessage) < 0), + (NumberFilter(Field(), DummyMessage, DummyMessageAnd, DummyMessage) < 0, NumberFilter(Field(), DummyMessage, DummyMessageAnd, DummyMessage) >= 0), + (NumberFilter(Field(), DummyMessage, DummyMessageAnd, DummyMessage) <= 0, NumberFilter(Field(), DummyMessage, DummyMessageAnd, DummyMessage) > 0) + +]) +def test_inversion(filt: Filter, inverted: Filter): + assert filt.operator != inverted.operator or filt.value == (not inverted.value) # In case of BooleanFilter, the value is inverted, not the operator + assert (~filt).operator == inverted.operator and (~filt).value == inverted.value + assert filt.operator == (~(~filt)).operator and filt.value == (~(~filt)).value + + +@pytest.mark.parametrize("filt", [ + (StringFilter(Field(), DummyMessage, DummyMessageAnd, DummyMessage).startswith("Test")), + (StringFilter(Field(), DummyMessage, DummyMessageAnd, DummyMessage).endswith("Test")), + (StringFilter(Field(), DummyMessage, DummyMessageAnd, DummyMessage)) # No op +]) +def test_inversion_raises(filt: Filter): + with pytest.raises(Exception): + test = ~filt + print(test) diff --git a/packages/python/tests/submitter_test.py b/packages/python/tests/submitter_test.py index b57de1c1a..c849efdc0 100644 --- a/packages/python/tests/submitter_test.py +++ b/packages/python/tests/submitter_test.py @@ -195,8 +195,8 @@ def test_armonik_submitter_should_list_tasks(session_ids, task_ids, included_sta assert inner.task_filter is not None assert all(map(lambda x: x[1] == session_ids[x[0]], enumerate(inner.task_filter.session.ids))) assert all(map(lambda x: x[1] == task_ids[x[0]], enumerate(inner.task_filter.task.ids))) - assert all(map(lambda x: x[1] == included_statuses[x[0]].value, enumerate(inner.task_filter.included.statuses))) - assert all(map(lambda x: x[1] == excluded_statuses[x[0]].value, enumerate(inner.task_filter.excluded.statuses))) + assert all(map(lambda x: x[1] == included_statuses[x[0]], enumerate(inner.task_filter.included.statuses))) + assert all(map(lambda x: x[1] == excluded_statuses[x[0]], enumerate(inner.task_filter.excluded.statuses))) else: with pytest.raises(ValueError): _ = submitter.list_tasks(session_ids=session_ids, task_ids=task_ids, included_statuses=included_statuses, @@ -223,7 +223,7 @@ def test_armonik_submitter_should_get_status(): [ResultReply(result=DataChunk(data="payload".encode("utf-8"))), ResultReply(result=DataChunk(data_complete=True)), ResultReply(result=DataChunk(data="payload".encode("utf-8")))], [ResultReply( - error=TaskError(task_id="TaskId", errors=[Error(task_status=TaskStatus.ERROR.value, detail="TestError")]))], + error=TaskError(task_id="TaskId", errors=[Error(task_status=TaskStatus.ERROR, detail="TestError")]))], ] get_result_should_succeed = [ @@ -300,9 +300,9 @@ def test_armonik_submitter_wait_completion(session_ids, task_ids, included_statu assert inner.wait_request is not None assert all(map(lambda x: x[1] == session_ids[x[0]], enumerate(inner.wait_request.filter.session.ids))) assert all(map(lambda x: x[1] == task_ids[x[0]], enumerate(inner.wait_request.filter.task.ids))) - assert all(map(lambda x: x[1] == included_statuses[x[0]].value, + assert all(map(lambda x: x[1] == included_statuses[x[0]], enumerate(inner.wait_request.filter.included.statuses))) - assert all(map(lambda x: x[1] == excluded_statuses[x[0]].value, + assert all(map(lambda x: x[1] == excluded_statuses[x[0]], enumerate(inner.wait_request.filter.excluded.statuses))) assert not inner.wait_request.stop_on_first_task_error assert not inner.wait_request.stop_on_first_task_cancellation diff --git a/packages/python/tests/tasks_test.py b/packages/python/tests/tasks_test.py index a47ed98b9..752b2ac63 100644 --- a/packages/python/tests/tasks_test.py +++ b/packages/python/tests/tasks_test.py @@ -1,12 +1,22 @@ #!/usr/bin/env python3 -from typing import Optional +import dataclasses +from typing import Optional, List, Any, Union, Dict, Collection +from google.protobuf.timestamp_pb2 import Timestamp from datetime import datetime + +import pytest + from .common import DummyChannel from armonik.client import ArmoniKTasks +from armonik.client.tasks import TaskFieldFilter from armonik.common import TaskStatus, datetime_to_timestamp, Task +from armonik.common.filter import StringFilter, Filter from armonik.protogen.client.tasks_service_pb2_grpc import TasksStub from armonik.protogen.common.tasks_common_pb2 import GetTaskRequest, GetTaskResponse, TaskDetailed +from armonik.protogen.common.tasks_filters_pb2 import Filters, FilterField +from armonik.protogen.common.filters_common_pb2 import * +from armonik.protogen.common.tasks_fields_pb2 import * from .submitter_test import default_task_option @@ -20,9 +30,9 @@ def GetTask(self, request: GetTaskRequest) -> GetTaskResponse: self.task_request = request raw = TaskDetailed(id="TaskId", session_id="SessionId", owner_pod_id="PodId", parent_task_ids=["ParentTaskId"], data_dependencies=["DD"], expected_output_ids=["EOK"], retry_of_ids=["RetryId"], - status=TaskStatus.COMPLETED.value, status_message="Message", + status=TaskStatus.COMPLETED, status_message="Message", options=default_task_option.to_message(), - created_at=datetime_to_timestamp(datetime.now()), + created_at=datetime_to_timestamp(datetime.now()), started_at=datetime_to_timestamp(datetime.now()), submitted_at=datetime_to_timestamp(datetime.now()), ended_at=datetime_to_timestamp(datetime.now()), pod_ttl=datetime_to_timestamp(datetime.now()), @@ -61,3 +71,211 @@ def test_task_refresh(): assert current.parent_task_ids == ["ParentTaskId"] assert current.output assert current.output.success + + +def test_task_filters(): + filt: StringFilter = TaskFieldFilter.TASK_ID == "TaskId" + message = filt.to_message() + assert isinstance(message, FilterField) + assert message.field.WhichOneof("field") == "task_summary_field" + assert message.field.task_summary_field.field == TASK_SUMMARY_ENUM_FIELD_TASK_ID + assert message.filter_string.value == "TaskId" + assert message.filter_string.operator == FILTER_STRING_OPERATOR_EQUAL + + filt: StringFilter = TaskFieldFilter.TASK_ID != "TaskId" + message = filt.to_message() + assert isinstance(message, FilterField) + assert message.field.WhichOneof("field") == "task_summary_field" + assert message.field.task_summary_field.field == TASK_SUMMARY_ENUM_FIELD_TASK_ID + assert message.filter_string.value == "TaskId" + assert message.filter_string.operator == FILTER_STRING_OPERATOR_NOT_EQUAL + + +@dataclasses.dataclass +class SimpleFieldFilter: + field: Any + value: Any + operator: Any + + +@pytest.mark.parametrize("filt,n_or,n_and,filters", [ + ( + (TaskFieldFilter.INITIAL_TASK_ID == "TestId"), + 1, [1], + [ + SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_INITIAL_TASK_ID, "TestId", FILTER_STRING_OPERATOR_EQUAL) + ] + ), + ( + (TaskFieldFilter.APPLICATION_NAME.contains("TestName") & (TaskFieldFilter.CREATED_AT > Timestamp(seconds=1000, nanos=500))), + 1, [2], + [ + SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_NAME, "TestName", FILTER_STRING_OPERATOR_CONTAINS), + SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_CREATED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_AFTER) + ] + ), + ( + (((TaskFieldFilter.MAX_RETRIES <= 3) & ~(TaskFieldFilter.SESSION_ID == "SessionId")) | (TaskFieldFilter.task_options_key("MyKey").startswith("Start"))), + 2, [1, 2], + [ + SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_MAX_RETRIES, 3, FILTER_NUMBER_OPERATOR_LESS_THAN_OR_EQUAL), + SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_SESSION_ID, "SessionId", FILTER_STRING_OPERATOR_NOT_EQUAL), + SimpleFieldFilter("MyKey", "Start", FILTER_STRING_OPERATOR_STARTS_WITH) + ] + ), + ( + (((TaskFieldFilter.PRIORITY > 3) & ~(TaskFieldFilter.STATUS == TaskStatus.COMPLETED) & TaskFieldFilter.APPLICATION_VERSION.contains("1.0")) | (TaskFieldFilter.ENGINE_TYPE.endswith("Test") & (TaskFieldFilter.ENDED_AT <= Timestamp(seconds=1000, nanos=500)))), + 2, [2, 3], + [ + SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_PRIORITY, 3, FILTER_NUMBER_OPERATOR_GREATER_THAN), + SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_STATUS, TaskStatus.COMPLETED, FILTER_STATUS_OPERATOR_NOT_EQUAL), + SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_VERSION, "1.0", FILTER_STRING_OPERATOR_CONTAINS), + SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_ENGINE_TYPE, "Test", FILTER_STRING_OPERATOR_ENDS_WITH), + SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_ENDED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_BEFORE_OR_EQUAL), + ] + ), + ( + (((TaskFieldFilter.PRIORITY >= 3) * -(TaskFieldFilter.STATUS != TaskStatus.COMPLETED) * -TaskFieldFilter.APPLICATION_VERSION.contains("1.0")) + (TaskFieldFilter.ENGINE_TYPE.endswith("Test") * (TaskFieldFilter.ENDED_AT <= Timestamp(seconds=1000, nanos=500)))), + 2, [2, 3], + [ + SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_PRIORITY, 3, FILTER_NUMBER_OPERATOR_GREATER_THAN_OR_EQUAL), + SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_STATUS, TaskStatus.COMPLETED, FILTER_STATUS_OPERATOR_EQUAL), + SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_VERSION, "1.0", FILTER_STRING_OPERATOR_NOT_CONTAINS), + SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_ENGINE_TYPE, "Test", FILTER_STRING_OPERATOR_ENDS_WITH), + SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_ENDED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_BEFORE_OR_EQUAL), + ] + ) +]) +def test_filter_combination(filt: Filter, n_or: int, n_and: List[int], filters: List[SimpleFieldFilter]): + filt = filt.to_disjunction() + assert len(filt._filters) == n_or + sorted_n_and = sorted(n_and) + sorted_actual = sorted([len(f) for f in filt._filters]) + assert len(sorted_n_and) == len(sorted_actual) + assert all((sorted_n_and[i] == sorted_actual[i] for i in range(len(sorted_actual)))) + for f in filt._filters: + for ff in f: + field_value = getattr(ff.field, ff.field.WhichOneof("field")).field + for i, expected in enumerate(filters): + if expected.field == field_value and expected.value == ff.value and expected.operator == ff.operator: + filters.pop(i) + break + else: + print(f"Could not find {str(ff)}") + assert False + assert len(filters) == 0 + + +def test_name_from_value(): + assert TaskStatus.name_from_value(TaskStatus.COMPLETED) == "TASK_STATUS_COMPLETED" + + +class BasicFilterAnd: + + def __setattr__(self, key, value): + self.__dict__[key] = value + + def __getattr__(self, item): + return self.__dict__[item] + + +@pytest.mark.parametrize("filt,n_or,n_and,filters,expected_type", [ + ( + (TaskFieldFilter.INITIAL_TASK_ID == "TestId"), + 1, [1], + [ + SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_INITIAL_TASK_ID, "TestId", FILTER_STRING_OPERATOR_EQUAL) + ], + 0 + ), + ( + (TaskFieldFilter.APPLICATION_NAME.contains("TestName") & (TaskFieldFilter.CREATED_AT > Timestamp(seconds=1000, nanos=500))), + 1, [2], + [ + SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_NAME, "TestName", FILTER_STRING_OPERATOR_CONTAINS), + SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_CREATED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_AFTER) + ], + 1 + ), + ( + (((TaskFieldFilter.MAX_RETRIES <= 3) & ~(TaskFieldFilter.SESSION_ID == "SessionId")) | (TaskFieldFilter.task_options_key("MyKey").startswith("Start"))), + 2, [1, 2], + [ + SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_MAX_RETRIES, 3, FILTER_NUMBER_OPERATOR_LESS_THAN_OR_EQUAL), + SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_SESSION_ID, "SessionId", FILTER_STRING_OPERATOR_NOT_EQUAL), + SimpleFieldFilter("MyKey", "Start", FILTER_STRING_OPERATOR_STARTS_WITH) + ], + 2 + ), + ( + (((TaskFieldFilter.PRIORITY > 3) & ~(TaskFieldFilter.STATUS == TaskStatus.COMPLETED) & TaskFieldFilter.APPLICATION_VERSION.contains("1.0")) | (TaskFieldFilter.ENGINE_TYPE.endswith("Test") & (TaskFieldFilter.ENDED_AT <= Timestamp(seconds=1000, nanos=500)))), + 2, [2, 3], + [ + SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_PRIORITY, 3, FILTER_NUMBER_OPERATOR_GREATER_THAN), + SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_STATUS, TaskStatus.COMPLETED, FILTER_STATUS_OPERATOR_NOT_EQUAL), + SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_VERSION, "1.0", FILTER_STRING_OPERATOR_CONTAINS), + SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_ENGINE_TYPE, "Test", FILTER_STRING_OPERATOR_ENDS_WITH), + SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_ENDED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_BEFORE_OR_EQUAL), + ], + 2 + ), + ( + (((TaskFieldFilter.PRIORITY >= 3) * -(TaskFieldFilter.STATUS != TaskStatus.COMPLETED) * -TaskFieldFilter.APPLICATION_VERSION.contains("1.0")) + (TaskFieldFilter.ENGINE_TYPE.endswith("Test") * (TaskFieldFilter.ENDED_AT <= Timestamp(seconds=1000, nanos=500)))), + 2, [2, 3], + [ + SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_PRIORITY, 3, FILTER_NUMBER_OPERATOR_GREATER_THAN_OR_EQUAL), + SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_STATUS, TaskStatus.COMPLETED, FILTER_STATUS_OPERATOR_EQUAL), + SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_VERSION, "1.0", FILTER_STRING_OPERATOR_NOT_CONTAINS), + SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_ENGINE_TYPE, "Test", FILTER_STRING_OPERATOR_ENDS_WITH), + SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_ENDED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_BEFORE_OR_EQUAL), + ], + 2 + ), + ( + (((TaskFieldFilter.PRIORITY >= 3) * -(TaskFieldFilter.STATUS != TaskStatus.COMPLETED) * -TaskFieldFilter.APPLICATION_VERSION.contains("1.0")) + (TaskFieldFilter.ENGINE_TYPE.endswith("Test") * (TaskFieldFilter.ENDED_AT <= Timestamp(seconds=1000, nanos=500)))) + (((TaskFieldFilter.MAX_RETRIES <= 3) & ~(TaskFieldFilter.SESSION_ID == "SessionId")) | (TaskFieldFilter.task_options_key("MyKey").startswith("Start"))), + 4, [2, 3, 2, 1], + [ + SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_PRIORITY, 3, FILTER_NUMBER_OPERATOR_GREATER_THAN_OR_EQUAL), + SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_STATUS, TaskStatus.COMPLETED, FILTER_STATUS_OPERATOR_EQUAL), + SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_VERSION, "1.0", FILTER_STRING_OPERATOR_NOT_CONTAINS), + SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_ENGINE_TYPE, "Test", FILTER_STRING_OPERATOR_ENDS_WITH), + SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_ENDED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_BEFORE_OR_EQUAL), + SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_MAX_RETRIES, 3, FILTER_NUMBER_OPERATOR_LESS_THAN_OR_EQUAL), + SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_SESSION_ID, "SessionId", FILTER_STRING_OPERATOR_NOT_EQUAL), + SimpleFieldFilter("MyKey", "Start", FILTER_STRING_OPERATOR_STARTS_WITH) + ], + 2 + ) +]) +def test_taskfilter_to_message(filt: Filter, n_or: int, n_and: List[int], filters: List[SimpleFieldFilter], expected_type: int): + print(filt) + message = filt.to_message() + conjs: Collection = [] + if expected_type == 2: # Disjunction + conjs: Collection = getattr(message, "or") + assert len(conjs) == n_or + sorted_n_and = sorted(n_and) + sorted_actual = sorted([len(getattr(f, "and")) for f in conjs]) + assert len(sorted_n_and) == len(sorted_actual) + assert all((sorted_n_and[i] == sorted_actual[i] for i in range(len(sorted_actual)))) + + if expected_type == 1: # Conjunction + conjs: Collection = [message] + + if expected_type == 0: # Simple filter + m = BasicFilterAnd() + setattr(m, "and", [message]) + conjs: Collection = [m] + + for conj in conjs: + basics = getattr(conj, "and") + for f in basics: + field_value = getattr(f.field, f.field.WhichOneof("field")).field + for i, expected in enumerate(filters): + if expected.field == field_value and expected.value == getattr(f, f.WhichOneof("value_condition")).value and expected.operator == getattr(f, f.WhichOneof("value_condition")).operator: + filters.pop(i) + break + else: + print(f"Could not find {str(f)}") + assert False + assert len(filters) == 0