Skip to content

Commit

Permalink
Fixed worker
Browse files Browse the repository at this point in the history
  • Loading branch information
dbrasseur-aneo committed Mar 21, 2023
1 parent 9b051fe commit 824e870
Show file tree
Hide file tree
Showing 16 changed files with 168 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .editorconfig
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ end_of_line = lf
insert_final_newline = true
trim_trailing_whitespace = true

[{packages/python/** , examples/python/**}]
[{packages/python/**,examples/python/**}]
indent_size = 4
max_line_length = off
27 changes: 27 additions & 0 deletions examples/python/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import List, Union
import json


class Payload:
    """Input of a task: a list of values plus the subtasking threshold.

    ``values`` holds either numbers or strings; ``subtask_threshold``
    defaults to 2. Instances travel as UTF-8 encoded JSON objects with
    the keys ``values`` and ``subtask_threshold``.
    """

    def __init__(self, values: List[Union[float, str]], subtask_threshold=2):
        self.values = values
        self.subtask_threshold = subtask_threshold

    def serialize(self) -> bytes:
        """Return this payload as UTF-8 encoded JSON bytes."""
        document = {
            "values": self.values,
            "subtask_threshold": self.subtask_threshold,
        }
        text = json.dumps(document)
        return text.encode("utf-8")

    @classmethod
    def deserialize(cls, payload: bytes) -> "Payload":
        """Build a payload from bytes produced by :meth:`serialize`."""
        # The JSON keys map one-to-one onto the constructor's keyword arguments.
        fields = json.loads(payload.decode("utf-8"))
        return cls(**fields)


class Result:
    """Output of a task: a single value, transported as UTF-8 encoded JSON."""

    def __init__(self, value: float):
        self.value = value

    def serialize(self) -> bytes:
        """Return this result as UTF-8 encoded JSON bytes (single key ``value``)."""
        text = json.dumps({"value": self.value})
        return text.encode("utf-8")

    @classmethod
    def deserialize(cls, payload: bytes) -> "Result":
        """Build a result from bytes produced by :meth:`serialize`."""
        # The decoded JSON object is forwarded as keyword arguments to the constructor.
        fields = json.loads(payload.decode("utf-8"))
        return cls(**fields)
68 changes: 66 additions & 2 deletions examples/python/worker.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,75 @@
import logging
import os

import grpc
from armonik.worker import ArmoniKWorker, TaskHandler
from armonik.common import Output
from armonik.worker import get_worker_logger
from armonik.common import Output, TaskDefinition
from typing import List

from common import Payload, Result

logger = get_worker_logger("ArmoniKWorker", level=logging.INFO)


# Actual computation
# Task processing
def processor(task_handler: TaskHandler) -> Output:
    """Process one task: compute, aggregate, or split the work into subtasks.

    The payload is deserialized and handled according to its content:

    * no values: a ``Result(0.0)`` is sent to the first expected result,
      if there is one;
    * first value is a string: the values are treated as keys into the
      task's data dependencies, whose results are aggregated;
    * at most one value, or no more values than the subtask threshold:
      the values are aggregated directly;
    * otherwise: the values are split in two halves, one subtask is created
      per half and a third subtask aggregates their outputs into this
      task's expected results.

    Returns ``Output()`` normally, or ``Output(message)`` when the subtask
    submission reported errors.
    """
    payload = Payload.deserialize(task_handler.payload)
    values = payload.values
    count = len(values)

    # No values
    if count == 0:
        if len(task_handler.expected_results) > 0:
            task_handler.send_result(task_handler.expected_results[0], Result(0.0).serialize())
        logger.info("No values")
        return Output()

    # Aggregation task: the values are the ids of the results to combine
    if isinstance(values[0], str):
        results = []
        for dependency_id in values:
            raw_result = task_handler.data_dependencies[dependency_id]
            results.append(Result.deserialize(raw_result).value)
        task_handler.send_result(task_handler.expected_results[0], Result(aggregate(results)).serialize())
        logger.info(f"Aggregated {len(results)} values")
        return Output()

    # Compute: the input is small enough to be handled by this task directly
    if count <= 1 or count <= payload.subtask_threshold:
        task_handler.send_result(task_handler.expected_results[0], Result(aggregate(values)).serialize())
        logger.info(f"Computed {count} values")
        return Output()

    # Subtasking: one subtask per half, each writing to a freshly requested output
    pivot = count // 2
    subtasks = [
        TaskDefinition(
            payload=Payload(values=half, subtask_threshold=payload.subtask_threshold).serialize(),
            expected_outputs=[task_handler.request_output_id()],
        )
        for half in (values[:pivot], values[pivot:])
    ]
    aggregate_dependencies = [subtask.expected_outputs[0] for subtask in subtasks]
    # The final subtask aggregates both halves and produces this task's expected results.
    subtasks.append(
        TaskDefinition(
            Payload(values=aggregate_dependencies).serialize(),
            expected_outputs=task_handler.expected_results,
            data_dependencies=aggregate_dependencies,
        )
    )

    # The list always holds the three definitions built above, so it is submitted unconditionally.
    submitted, errors = task_handler.create_tasks(subtasks)
    if len(errors) > 0:
        message = f"Errors while submitting subtasks : {', '.join(errors)}"
        logger.error(message)
        return Output(message)
    logger.info(f"Submitted {len(submitted)} subtasks")
    return Output()


def aggregate(values: List[float]) -> float:
    """Return the sum of ``values`` (``0`` for an empty list)."""
    total = sum(values)
    return total


def main():
    """Start the example worker.

    Endpoints are read from the ``ComputePlane__*`` environment variables.
    A channel whose socket type is ``unixdomainsocket`` (the default) uses
    the ``unix://`` scheme, any other value uses ``http://``.
    """
    unix_socket_type = "unixdomainsocket"

    def scheme_for(socket_type_variable: str) -> str:
        # Pick the URL scheme matching the configured socket type.
        if os.getenv(socket_type_variable, unix_socket_type) == unix_socket_type:
            return "unix://"
        return "http://"

    worker_address = os.getenv("ComputePlane__WorkerChannel__Address", "/cache/armonik_worker.sock")
    agent_address = os.getenv("ComputePlane__AgentChannel__Address", "/cache/armonik_agent.sock")
    worker_endpoint = scheme_for("ComputePlane__WorkerChannel__SocketType") + worker_address
    agent_endpoint = scheme_for("ComputePlane__AgentChannel__SocketType") + agent_address

    logger.info("Worker Started")
    # The agent channel stays open for as long as the worker is serving.
    with grpc.insecure_channel(agent_endpoint) as agent_channel:
        worker = ArmoniKWorker(agent_channel, processor, logger=logger)
        logger.info("Worker Connected")
        worker.start(worker_endpoint)


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion packages/python/proto2python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ mkdir -p $ARMONIK_WORKER $ARMONIK_CLIENT $ARMONIK_COMMON $PACKAGE_PATH
python -m pip install --upgrade pip
python -m venv $PYTHON_VENV
source $PYTHON_VENV/bin/activate
python -m pip install build grpcio grpcio-tools click
python -m pip install build grpcio grpcio-tools click seqlog

unset proto_files
for proto in ${armonik_worker_files[@]}; do
Expand Down
1 change: 1 addition & 0 deletions packages/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ classifiers = [
dependencies = [
"grpcio",
"grpcio-tools",
"seqlog"
]

[project.urls]
Expand Down
2 changes: 1 addition & 1 deletion packages/python/src/armonik/client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .submitter import ArmoniKSubmitter
from .tasks import ArmoniKTasks
from .tasks import ArmoniKTasks
9 changes: 4 additions & 5 deletions packages/python/src/armonik/client/submitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

from grpc import Channel

from ..common.helpers import get_task_filter
from ..common.objects import Configuration, TaskOptions, TaskStatus, TaskDefinition, Task
from ..common import get_task_filter, TaskOptions, TaskDefinition, Task, TaskStatus
from ..protogen.client.submitter_service_pb2_grpc import SubmitterStub
from ..protogen.common.objects_pb2 import Empty, TaskRequest, ResultRequest, DataChunk, InitTaskRequest, \
TaskRequestHeader, Configuration
Expand Down Expand Up @@ -43,7 +42,7 @@ def __init__(self, grpc_channel: Channel):
self._client = SubmitterStub(grpc_channel)

def get_service_configuration(self) -> Configuration:
return Configuration(self._client.GetServiceConfiguration(Empty()))
return self._client.GetServiceConfiguration(Empty())

def create_session(self, default_task_options: TaskOptions, partition_ids: Optional[List[str]] = None) -> str:
if partition_ids is None:
Expand Down Expand Up @@ -96,7 +95,7 @@ def get_task_status(self, task_ids: List[str]) -> Dict[str, TaskStatus]:
request = GetTaskStatusRequest()
request.task_ids.extend(task_ids)
reply = self._client.GetTaskStatus(request)
return dict([(s.task_id, TaskStatus(s.status)) for s in reply.id_statuses])
return dict([(s.task_id, s.status) for s in reply.id_statuses])

def wait_for_completion(self,
session_ids: Optional[List[str]] = None,
Expand All @@ -105,7 +104,7 @@ def wait_for_completion(self,
excluded_statuses: Optional[List[TaskStatus]] = None,
stop_on_first_task_error: bool = False,
stop_on_first_task_cancellation: bool = False) -> Dict[TaskStatus, int]:
return dict([(TaskStatus(sc.status), sc.count) for sc in self._client.WaitForCompletion(
return dict([(sc.status, sc.count) for sc in self._client.WaitForCompletion(
WaitRequest(filter=get_task_filter(session_ids, task_ids, included_statuses, excluded_statuses),
stop_on_first_task_error=stop_on_first_task_error,
stop_on_first_task_cancellation=stop_on_first_task_cancellation)).values])
Expand Down
4 changes: 2 additions & 2 deletions packages/python/src/armonik/client/tasks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from grpc import Channel

from ..common.objects import Task, TaskStatus, TaskOptions
from ..common import Task, TaskOptions
from ..protogen.client.tasks_service_pb2_grpc import TasksStub
from ..protogen.common.tasks_common_pb2 import GetTaskRequest

Expand All @@ -19,7 +19,7 @@ def get_task(self, task_id: str) -> Task:
task.parent_task_ids.extend(raw.parent_task_ids)
task.data_dependencies.extend(raw.data_dependencies)
task.expected_output_ids.extend(raw.expected_output_ids)
task.status = TaskStatus(raw.status)
task.status = raw.status
task.status_message = raw.status_message
task.options = TaskOptions.from_message(raw.options)
task.retry_of_ids.extend(raw.retry_of_ids)
Expand Down
5 changes: 3 additions & 2 deletions packages/python/src/armonik/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .helpers import datetime_to_timestamp, timestamp_to_datetime, duration_to_timedelta, timedelta_to_duration
from .objects import Task, TaskDefinition, TaskOptions, Output
from .helpers import datetime_to_timestamp, timestamp_to_datetime, duration_to_timedelta, timedelta_to_duration, get_task_filter
from .objects import Task, TaskDefinition, TaskOptions, Output
from .enumwrapper import HealthCheckStatus, TaskStatus
26 changes: 26 additions & 0 deletions packages/python/src/armonik/common/enumwrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import enum

from ..protogen.common.task_status_pb2 import *
from ..protogen.common.worker_common_pb2 import HealthCheckReply

# This file is necessary because the grpc values don't have the proper type


class HealthCheckStatus(enum.Enum):
    """Python enum wrapping the serving-status values exposed on ``HealthCheckReply``.

    Each member's ``value`` is the corresponding generated constant, so
    ``member.value`` can be passed wherever the raw status is expected.
    """

    NOT_SERVING = HealthCheckReply.NOT_SERVING
    SERVING = HealthCheckReply.SERVING
    UNKNOWN = HealthCheckReply.UNKNOWN


class TaskStatus(enum.Enum):
    """Python enum wrapping the ``TASK_STATUS_*`` constants of ``task_status_pb2``.

    Each member's ``value`` is the corresponding generated constant;
    ``TaskStatus(raw_status)`` converts a raw status back into a member.
    """

    # Members are named after the generated constants without the TASK_STATUS_ prefix.
    CANCELLED = TASK_STATUS_CANCELLED
    CANCELLING = TASK_STATUS_CANCELLING
    COMPLETED = TASK_STATUS_COMPLETED
    CREATING = TASK_STATUS_CREATING
    DISPATCHED = TASK_STATUS_DISPATCHED
    ERROR = TASK_STATUS_ERROR
    PROCESSED = TASK_STATUS_PROCESSED
    PROCESSING = TASK_STATUS_PROCESSING
    SUBMITTED = TASK_STATUS_SUBMITTED
    TIMEOUT = TASK_STATUS_TIMEOUT
    UNSPECIFIED = TASK_STATUS_UNSPECIFIED
2 changes: 1 addition & 1 deletion packages/python/src/armonik/common/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import google.protobuf.duration_pb2 as duration
import google.protobuf.timestamp_pb2 as timestamp

from .objects import TaskStatus
from ..protogen.common.submitter_common_pb2 import TaskFilter
from .enumwrapper import TaskStatus


def get_task_filter(session_ids: Optional[List[str]] = None, task_ids: Optional[List[str]] = None,
Expand Down
11 changes: 5 additions & 6 deletions packages/python/src/armonik/common/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
from datetime import timedelta, datetime
from typing import Optional, List

from client.tasks import ArmoniKTasks
from protogen.common.tasks_common_pb2 import TaskRaw
from ..protogen.common.tasks_common_pb2 import TaskRaw
from .helpers import duration_to_timedelta, timedelta_to_duration, timestamp_to_datetime
from ..protogen.common.objects_pb2 import Empty, Output as WorkerOutput
from ..protogen.common.task_status_pb2 import *
from .enumwrapper import TaskStatus


@dataclass()
Expand Down Expand Up @@ -84,7 +83,7 @@ class Task:
data_dependencies: Optional[List[str]] = None
expected_output_ids: Optional[List[str]] = None
retry_of_ids: Optional[List[str]] = None
status: TaskStatus = TaskStatus.TASK_STATUS_UNSPECIFIED
status: TaskStatus = TaskStatus.UNSPECIFIED
status_message: Optional[str] = None
options: Optional[TaskOptions] = None
created_at: Optional[datetime] = None
Expand All @@ -97,7 +96,7 @@ class Task:
received_at: Optional[datetime] = None
acquired_at: Optional[datetime] = None

def refresh(self, task_client: ArmoniKTasks) -> None:
def refresh(self, task_client) -> None:
result = task_client.get_task(self.id)
self.session_id = result.session_id
self.owner_pod_id = result.owner_pod_id
Expand Down Expand Up @@ -128,7 +127,7 @@ def from_message(cls, task_raw: TaskRaw) -> "Task":
data_dependencies=list(task_raw.data_dependencies),
expected_output_ids=list(task_raw.expected_output_ids),
retry_of_ids=list(task_raw.retry_of_ids),
status=task_raw.status,
status=TaskStatus(task_raw.status),
status_message=task_raw.status_message,
options=task_raw.options,
created_at=timestamp_to_datetime(task_raw.created_at),
Expand Down
3 changes: 2 additions & 1 deletion packages/python/src/armonik/worker/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .worker import ArmoniKWorker
from .taskhandler import TaskHandler
from .taskhandler import TaskHandler
from .seqlogger import get_worker_logger
18 changes: 18 additions & 0 deletions packages/python/src/armonik/worker/seqlogger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from logging import Logger, getLogger, Formatter
from seqlog import ConsoleStructuredLogHandler
from typing import Dict

_worker_loggers: Dict[str, Logger] = {}


def get_worker_logger(name: str, level: int) -> Logger:
    """Return the worker logger registered under ``name``, creating it on first use.

    On the first call for a given name, the logger's existing handlers are
    removed, a ``ConsoleStructuredLogHandler`` with a ``{``-style formatter
    is attached, and the level is set to ``level``. Later calls with the same
    name return the stored logger unchanged; their ``level`` is ignored.
    """
    existing = _worker_loggers.get(name)
    if existing is not None:
        return existing

    worker_logger = getLogger(name)
    # Start from a clean slate so the structured handler is the only one attached.
    worker_logger.handlers.clear()
    console_handler = ConsoleStructuredLogHandler()
    console_handler.setFormatter(Formatter(style="{"))
    worker_logger.addHandler(console_handler)
    worker_logger.setLevel(level)

    _worker_loggers[name] = worker_logger
    return worker_logger
4 changes: 2 additions & 2 deletions packages/python/src/armonik/worker/taskhandler.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import uuid
from typing import Optional, Dict, List, Tuple, Union

from ..common.objects import TaskOptions, Configuration, TaskDefinition, Task
from ..common import TaskOptions, TaskDefinition, Task
from ..protogen.common.agent_common_pb2 import Result, CreateTaskRequest
from ..protogen.common.objects_pb2 import TaskRequest, InitKeyedDataStream, DataChunk, InitTaskRequest, TaskRequestHeader
from ..protogen.common.objects_pb2 import TaskRequest, InitKeyedDataStream, DataChunk, InitTaskRequest, TaskRequestHeader, Configuration
from ..protogen.worker.agent_service_pb2_grpc import AgentStub


Expand Down
13 changes: 8 additions & 5 deletions packages/python/src/armonik/worker/worker.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
import logging
import traceback
from concurrent import futures
from typing import Callable

import grpc
from grpc import Channel

from ..common.objects import Output
from .seqlogger import get_worker_logger
from ..common import Output, HealthCheckStatus
from ..protogen.common.worker_common_pb2 import ProcessReply, HealthCheckReply
from ..protogen.worker.agent_service_pb2_grpc import AgentStub
from ..protogen.worker.worker_service_pb2_grpc import WorkerServicer, add_WorkerServicer_to_server
from ..worker.taskhandler import TaskHandler
from .taskhandler import TaskHandler


class ArmoniKWorker(WorkerServicer):
def __init__(self, agent_channel: Channel, processing_function: Callable[[TaskHandler], Output], health_check: Callable[[], HealthCheckReply.ServingStatus] = lambda: HealthCheckReply.SERVING):
def __init__(self, agent_channel: Channel, processing_function: Callable[[TaskHandler], Output], health_check: Callable[[], HealthCheckStatus] = lambda: HealthCheckStatus.SERVING, logger=get_worker_logger("ArmoniKWorker", logging.INFO)):
self.health_check = health_check
self.processing_function = processing_function
self._client = AgentStub(agent_channel)
self._logger = logger

def start(self, endpoint: str):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
Expand All @@ -30,7 +33,7 @@ def Process(self, request_iterator, context) -> ProcessReply:
task_handler = TaskHandler.create(request_iterator, self._client)
return ProcessReply(output=self.processing_function(task_handler).to_message())
except Exception as e:
print(f"Failed task {''.join(traceback.format_exception(e))}")
self._logger.exception(f"Failed task {''.join(traceback.format_exception(e))}", exc_info=e)

def HealthCheck(self, request, context) -> HealthCheckReply:
return HealthCheckReply(status=self.health_check())
return HealthCheckReply(status=self.health_check().value)

0 comments on commit 824e870

Please sign in to comment.