diff --git a/.editorconfig b/.editorconfig index 2315b0026..5ed1b0dfd 100644 --- a/.editorconfig +++ b/.editorconfig @@ -8,6 +8,6 @@ end_of_line = lf insert_final_newline = true trim_trailing_whitespace = true -[{packages/python/** , examples/python/**}] +[{packages/python/**,examples/python/**}] indent_size = 4 max_line_length = off \ No newline at end of file diff --git a/examples/python/common.py b/examples/python/common.py index e69de29bb..6f246f84e 100644 --- a/examples/python/common.py +++ b/examples/python/common.py @@ -0,0 +1,27 @@ +from typing import List, Union +import json + + +class Payload: + def __init__(self, values: List[Union[float, str]], subtask_threshold=2): + self.values = values + self.subtask_threshold = subtask_threshold + + def serialize(self) -> bytes: + return json.dumps({"values": self.values, "subtask_threshold": self.subtask_threshold}).encode("utf-8") + + @classmethod + def deserialize(cls, payload: bytes) -> "Payload": + return cls(**json.loads(payload.decode("utf-8"))) + + +class Result: + def __init__(self, value: float): + self.value = value + + def serialize(self) -> bytes: + return json.dumps({"value": self.value}).encode("utf-8") + + @classmethod + def deserialize(cls, payload: bytes) -> "Result": + return cls(**json.loads(payload.decode("utf-8"))) diff --git a/examples/python/worker.py b/examples/python/worker.py index 3c54771e1..91678b295 100644 --- a/examples/python/worker.py +++ b/examples/python/worker.py @@ -1,11 +1,75 @@ +import logging +import os + +import grpc from armonik.worker import ArmoniKWorker, TaskHandler -from armonik.common import Output +from armonik.worker import get_worker_logger +from armonik.common import Output, TaskDefinition +from typing import List + +from common import Payload, Result + +logger = get_worker_logger("ArmoniKWorker", level=logging.INFO) -# Actual computation +# Task processing def processor(task_handler: TaskHandler) -> Output: + payload = Payload.deserialize(task_handler.payload) + # No values + if len(payload.values) == 0: + if len(task_handler.expected_results) > 0: + task_handler.send_result(task_handler.expected_results[0], Result(0.0).serialize()) + logger.info("No values") + return Output() + + if isinstance(payload.values[0], str): + # Aggregation task + results = [Result.deserialize(task_handler.data_dependencies[r]).value for r in payload.values] + task_handler.send_result(task_handler.expected_results[0], Result(aggregate(results)).serialize()) + logger.info(f"Aggregated {len(results)} values") + return Output() + + if len(payload.values) <= 1 or len(payload.values) <= payload.subtask_threshold: + # Compute + task_handler.send_result(task_handler.expected_results[0], Result(aggregate(payload.values)).serialize()) + logger.info(f"Computed {len(payload.values)} values") + return Output() + + # Subtasking + pivot = len(payload.values) // 2 + lower = payload.values[:pivot] + upper = payload.values[pivot:] + subtasks = [] + for vals in [lower, upper]: + new_payload = Payload(values=vals, subtask_threshold=payload.subtask_threshold).serialize() + subtasks.append(TaskDefinition(payload=new_payload, expected_outputs=[task_handler.request_output_id()])) + aggregate_dependencies = [s.expected_outputs[0] for s in subtasks] + subtasks.append(TaskDefinition(Payload(values=aggregate_dependencies).serialize(), expected_outputs=task_handler.expected_results, data_dependencies=aggregate_dependencies)) + if len(subtasks) > 0: + submitted, errors = task_handler.create_tasks(subtasks) + if len(errors) > 0: + message = f"Errors while submitting subtasks : {', '.join(errors)}" + logger.error(message) + return Output(message) + logger.info(f"Submitted {len(submitted)} subtasks") return Output() +def aggregate(values: List[float]) -> float: + return sum(values) + + def main(): + worker_scheme = "unix://" if os.getenv("ComputePlane__WorkerChannel__SocketType", "unixdomainsocket") == "unixdomainsocket" else "http://" + agent_scheme = "unix://" if os.getenv("ComputePlane__AgentChannel__SocketType", "unixdomainsocket") == "unixdomainsocket" else "http://" + worker_endpoint = worker_scheme+os.getenv("ComputePlane__WorkerChannel__Address", "/cache/armonik_worker.sock") + agent_endpoint = agent_scheme+os.getenv("ComputePlane__AgentChannel__Address", "/cache/armonik_agent.sock") + logger.info("Worker Started") + with grpc.insecure_channel(agent_endpoint) as agent_channel: + worker = ArmoniKWorker(agent_channel, processor, logger=logger) + logger.info("Worker Connected") + worker.start(worker_endpoint) + +if __name__ == "__main__": + main() diff --git a/packages/python/proto2python.sh b/packages/python/proto2python.sh index d10ecc8ec..0c9fff1e0 100644 --- a/packages/python/proto2python.sh +++ b/packages/python/proto2python.sh @@ -31,7 +31,7 @@ mkdir -p $ARMONIK_WORKER $ARMONIK_CLIENT $ARMONIK_COMMON $PACKAGE_PATH python -m pip install --upgrade pip python -m venv $PYTHON_VENV source $PYTHON_VENV/bin/activate -python -m pip install build grpcio grpcio-tools click +python -m pip install build grpcio grpcio-tools click seqlog unset proto_files for proto in ${armonik_worker_files[@]}; do diff --git a/packages/python/pyproject.toml b/packages/python/pyproject.toml index db14dd4b0..0343998cd 100644 --- a/packages/python/pyproject.toml +++ b/packages/python/pyproject.toml @@ -18,6 +18,7 @@ classifiers = [ dependencies = [ "grpcio", "grpcio-tools", + "seqlog" ] [project.urls] diff --git a/packages/python/src/armonik/client/__init__.py b/packages/python/src/armonik/client/__init__.py index 81a0fa1ac..95f9a083e 100644 --- a/packages/python/src/armonik/client/__init__.py +++ b/packages/python/src/armonik/client/__init__.py @@ -1,2 +1,2 @@ from .submitter import ArmoniKSubmitter -from .tasks import ArmoniKTasks \ No newline at end of file +from .tasks import ArmoniKTasks diff --git a/packages/python/src/armonik/client/submitter.py b/packages/python/src/armonik/client/submitter.py index 0cf1e12b5..2874a6a52 100644 --- a/packages/python/src/armonik/client/submitter.py +++ b/packages/python/src/armonik/client/submitter.py @@ -3,8 +3,7 @@ from grpc import Channel -from ..common.helpers import get_task_filter -from ..common.objects import Configuration, TaskOptions, TaskStatus, TaskDefinition, Task +from ..common import get_task_filter, TaskOptions, TaskDefinition, Task, TaskStatus from ..protogen.client.submitter_service_pb2_grpc import SubmitterStub from ..protogen.common.objects_pb2 import Empty, TaskRequest, ResultRequest, DataChunk, InitTaskRequest, \ TaskRequestHeader, Configuration @@ -43,7 +42,7 @@ def __init__(self, grpc_channel: Channel): self._client = SubmitterStub(grpc_channel) def get_service_configuration(self) -> Configuration: - return Configuration(self._client.GetServiceConfiguration(Empty())) + return self._client.GetServiceConfiguration(Empty()) def create_session(self, default_task_options: TaskOptions, partition_ids: Optional[List[str]] = None) -> str: if partition_ids is None: @@ -96,7 +95,7 @@ def get_task_status(self, task_ids: List[str]) -> Dict[str, TaskStatus]: request = GetTaskStatusRequest() request.task_ids.extend(task_ids) reply = self._client.GetTaskStatus(request) - return dict([(s.task_id, TaskStatus(s.status)) for s in reply.id_statuses]) + return dict([(s.task_id, s.status) for s in reply.id_statuses]) def wait_for_completion(self, session_ids: Optional[List[str]] = None, @@ -105,7 +104,7 @@ def wait_for_completion(self, excluded_statuses: Optional[List[TaskStatus]] = None, stop_on_first_task_error: bool = False, stop_on_first_task_cancellation: bool = False) -> Dict[TaskStatus, int]: - return dict([(TaskStatus(sc.status), sc.count) for sc in self._client.WaitForCompletion( + return dict([(sc.status, sc.count) for sc in self._client.WaitForCompletion( WaitRequest(filter=get_task_filter(session_ids, task_ids, included_statuses, excluded_statuses), stop_on_first_task_error=stop_on_first_task_error, stop_on_first_task_cancellation=stop_on_first_task_cancellation)).values]) diff --git a/packages/python/src/armonik/client/tasks.py b/packages/python/src/armonik/client/tasks.py index 1f9af7888..f40d3dd83 100644 --- a/packages/python/src/armonik/client/tasks.py +++ b/packages/python/src/armonik/client/tasks.py @@ -1,6 +1,6 @@ from grpc import Channel -from ..common.objects import Task, TaskStatus, TaskOptions +from ..common import Task, TaskOptions from ..protogen.client.tasks_service_pb2_grpc import TasksStub from ..protogen.common.tasks_common_pb2 import GetTaskRequest @@ -19,7 +19,7 @@ def get_task(self, task_id: str) -> Task: task.parent_task_ids.extend(raw.parent_task_ids) task.data_dependencies.extend(raw.data_dependencies) task.expected_output_ids.extend(raw.expected_output_ids) - task.status = TaskStatus(raw.status) + task.status = raw.status task.status_message = raw.status_message task.options = TaskOptions.from_message(raw.options) task.retry_of_ids.extend(raw.retry_of_ids) diff --git a/packages/python/src/armonik/common/__init__.py b/packages/python/src/armonik/common/__init__.py index b0c074061..5b38e6bcf 100644 --- a/packages/python/src/armonik/common/__init__.py +++ b/packages/python/src/armonik/common/__init__.py @@ -1,2 +1,3 @@ -from .helpers import datetime_to_timestamp, timestamp_to_datetime, duration_to_timedelta, timedelta_to_duration -from .objects import Task, TaskDefinition, TaskOptions, Output \ No newline at end of file +from .helpers import datetime_to_timestamp, timestamp_to_datetime, duration_to_timedelta, timedelta_to_duration, get_task_filter +from .objects import Task, TaskDefinition, TaskOptions, Output +from .enumwrapper import HealthCheckStatus, TaskStatus diff --git a/packages/python/src/armonik/common/enumwrapper.py b/packages/python/src/armonik/common/enumwrapper.py new file mode 100644 index 000000000..11b634fc6 --- /dev/null +++ b/packages/python/src/armonik/common/enumwrapper.py @@ -0,0 +1,26 @@ +import enum + +from ..protogen.common.task_status_pb2 import * +from ..protogen.common.worker_common_pb2 import HealthCheckReply + +# This file is necessary because the grpc values don't have the proper type + + +class HealthCheckStatus(enum.Enum): + NOT_SERVING = HealthCheckReply.NOT_SERVING + SERVING = HealthCheckReply.SERVING + UNKNOWN = HealthCheckReply.UNKNOWN + + +class TaskStatus(enum.Enum): + CANCELLED = TASK_STATUS_CANCELLED + CANCELLING = TASK_STATUS_CANCELLING + COMPLETED = TASK_STATUS_COMPLETED + CREATING = TASK_STATUS_CREATING + DISPATCHED = TASK_STATUS_DISPATCHED + ERROR = TASK_STATUS_ERROR + PROCESSED = TASK_STATUS_PROCESSED + PROCESSING = TASK_STATUS_PROCESSING + SUBMITTED = TASK_STATUS_SUBMITTED + TIMEOUT = TASK_STATUS_TIMEOUT + UNSPECIFIED = TASK_STATUS_UNSPECIFIED diff --git a/packages/python/src/armonik/common/helpers.py b/packages/python/src/armonik/common/helpers.py index 6d1276276..774e61fad 100644 --- a/packages/python/src/armonik/common/helpers.py +++ b/packages/python/src/armonik/common/helpers.py @@ -5,8 +5,8 @@ import google.protobuf.duration_pb2 as duration import google.protobuf.timestamp_pb2 as timestamp -from .objects import TaskStatus from ..protogen.common.submitter_common_pb2 import TaskFilter +from .enumwrapper import TaskStatus def get_task_filter(session_ids: Optional[List[str]] = None, task_ids: Optional[List[str]] = None, diff --git a/packages/python/src/armonik/common/objects.py b/packages/python/src/armonik/common/objects.py index f5925b9c0..946d50444 100644 --- a/packages/python/src/armonik/common/objects.py +++ b/packages/python/src/armonik/common/objects.py @@ -2,11 +2,10 @@ from datetime import timedelta, datetime from typing import Optional, List -from client.tasks import ArmoniKTasks -from protogen.common.tasks_common_pb2 import TaskRaw +from ..protogen.common.tasks_common_pb2 import TaskRaw from .helpers import duration_to_timedelta, timedelta_to_duration, timestamp_to_datetime from ..protogen.common.objects_pb2 import Empty, Output as WorkerOutput -from ..protogen.common.task_status_pb2 import * +from .enumwrapper import TaskStatus @dataclass() @@ -84,7 +83,7 @@ class Task: data_dependencies: Optional[List[str]] = None expected_output_ids: Optional[List[str]] = None retry_of_ids: Optional[List[str]] = None - status: TaskStatus = TaskStatus.TASK_STATUS_UNSPECIFIED + status: TaskStatus = TaskStatus.UNSPECIFIED status_message: Optional[str] = None options: Optional[TaskOptions] = None created_at: Optional[datetime] = None @@ -97,7 +96,7 @@ class Task: received_at: Optional[datetime] = None acquired_at: Optional[datetime] = None - def refresh(self, task_client: ArmoniKTasks) -> None: + def refresh(self, task_client) -> None: result = task_client.get_task(self.id) self.session_id = result.session_id self.owner_pod_id = result.owner_pod_id @@ -128,7 +127,7 @@ def from_message(cls, task_raw: TaskRaw) -> "Task": data_dependencies=list(task_raw.data_dependencies), expected_output_ids=list(task_raw.expected_output_ids), retry_of_ids=list(task_raw.retry_of_ids), - status=task_raw.status, + status=TaskStatus(task_raw.status), status_message=task_raw.status_message, options=task_raw.options, created_at=timestamp_to_datetime(task_raw.created_at), diff --git a/packages/python/src/armonik/worker/__init__.py b/packages/python/src/armonik/worker/__init__.py index 92f804cdd..b3fef07e2 100644 --- a/packages/python/src/armonik/worker/__init__.py +++ b/packages/python/src/armonik/worker/__init__.py @@ -1,2 +1,3 @@ from .worker import ArmoniKWorker -from .taskhandler import TaskHandler \ No newline at end of file +from .taskhandler import TaskHandler +from .seqlogger import get_worker_logger diff --git a/packages/python/src/armonik/worker/seqlogger.py b/packages/python/src/armonik/worker/seqlogger.py new file mode 100644 index 000000000..a11b62f3d --- /dev/null +++ b/packages/python/src/armonik/worker/seqlogger.py @@ -0,0 +1,18 @@ +from logging import Logger, getLogger, Formatter +from seqlog import ConsoleStructuredLogHandler +from typing import Dict + +_worker_loggers: Dict[str, Logger] = {} + + +def get_worker_logger(name: str, level: int) -> Logger: + if name in _worker_loggers: + return _worker_loggers[name] + logger = getLogger(name) + logger.handlers.clear() + handler = ConsoleStructuredLogHandler() + handler.setFormatter(Formatter(style="{")) + logger.addHandler(handler) + logger.setLevel(level) + _worker_loggers[name] = logger + return _worker_loggers[name] diff --git a/packages/python/src/armonik/worker/taskhandler.py b/packages/python/src/armonik/worker/taskhandler.py index d62486068..4f9187374 100644 --- a/packages/python/src/armonik/worker/taskhandler.py +++ b/packages/python/src/armonik/worker/taskhandler.py @@ -1,9 +1,9 @@ import uuid from typing import Optional, Dict, List, Tuple, Union -from ..common.objects import TaskOptions, Configuration, TaskDefinition, Task +from ..common import TaskOptions, TaskDefinition, Task from ..protogen.common.agent_common_pb2 import Result, CreateTaskRequest -from ..protogen.common.objects_pb2 import TaskRequest, InitKeyedDataStream, DataChunk, InitTaskRequest, TaskRequestHeader +from ..protogen.common.objects_pb2 import TaskRequest, InitKeyedDataStream, DataChunk, InitTaskRequest, TaskRequestHeader, Configuration from ..protogen.worker.agent_service_pb2_grpc import AgentStub diff --git a/packages/python/src/armonik/worker/worker.py b/packages/python/src/armonik/worker/worker.py index 1be489b39..445a8f321 100644 --- a/packages/python/src/armonik/worker/worker.py +++ b/packages/python/src/armonik/worker/worker.py @@ -1,3 +1,4 @@ +import logging import traceback from concurrent import futures from typing import Callable @@ -5,18 +6,20 @@ import grpc from grpc import Channel -from ..common.objects import Output +from .seqlogger import get_worker_logger +from ..common import Output, HealthCheckStatus from ..protogen.common.worker_common_pb2 import ProcessReply, HealthCheckReply from ..protogen.worker.agent_service_pb2_grpc import AgentStub from ..protogen.worker.worker_service_pb2_grpc import WorkerServicer, add_WorkerServicer_to_server -from ..worker.taskhandler import TaskHandler +from .taskhandler import TaskHandler class ArmoniKWorker(WorkerServicer): - def __init__(self, agent_channel: Channel, processing_function: Callable[[TaskHandler], Output], health_check: Callable[[], HealthCheckReply.ServingStatus] = lambda: HealthCheckReply.SERVING): + def __init__(self, agent_channel: Channel, processing_function: Callable[[TaskHandler], Output], health_check: Callable[[], HealthCheckStatus] = lambda: HealthCheckStatus.SERVING, logger=get_worker_logger("ArmoniKWorker", logging.INFO)): self.health_check = health_check self.processing_function = processing_function self._client = AgentStub(agent_channel) + self._logger = logger def start(self, endpoint: str): server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) @@ -30,7 +33,7 @@ def Process(self, request_iterator, context) -> ProcessReply: task_handler = TaskHandler.create(request_iterator, self._client) return ProcessReply(output=self.processing_function(task_handler).to_message()) except Exception as e: - print(f"Failed task {''.join(traceback.format_exception(e))}") + self._logger.exception(f"Failed task {''.join(traceback.format_exception(e))}", exc_info=e) def HealthCheck(self, request, context) -> HealthCheckReply: - return HealthCheckReply(status=self.health_check()) + return HealthCheckReply(status=self.health_check().value)