-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9b051fe
commit 824e870
Showing
16 changed files
with
168 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from typing import List, Union | ||
import json | ||
|
||
|
||
class Payload: | ||
def __init__(self, values: List[Union[float, str]], subtask_threshold=2): | ||
self.values = values | ||
self.subtask_threshold = subtask_threshold | ||
|
||
def serialize(self) -> bytes: | ||
return json.dumps({"values": self.values, "subtask_threshold": self.subtask_threshold}).encode("utf-8") | ||
|
||
@classmethod | ||
def deserialize(cls, payload: bytes) -> "Payload": | ||
return cls(**json.loads(payload.decode("utf-8"))) | ||
|
||
|
||
class Result: | ||
def __init__(self, value: float): | ||
self.value = value | ||
|
||
def serialize(self) -> bytes: | ||
return json.dumps({"value": self.value}).encode("utf-8") | ||
|
||
@classmethod | ||
def deserialize(cls, payload: bytes) -> "Result": | ||
return cls(**json.loads(payload.decode("utf-8"))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,75 @@ | ||
import logging | ||
import os | ||
|
||
import grpc | ||
from armonik.worker import ArmoniKWorker, TaskHandler | ||
from armonik.common import Output | ||
from armonik.worker import get_worker_logger | ||
from armonik.common import Output, TaskDefinition | ||
from typing import List | ||
|
||
from common import Payload, Result | ||
|
||
logger = get_worker_logger("ArmoniKWorker", level=logging.INFO) | ||
|
||
|
||
# Actual computation | ||
# Task processing | ||
def processor(task_handler: TaskHandler) -> Output: | ||
payload = Payload.deserialize(task_handler.payload) | ||
# No values | ||
if len(payload.values) == 0: | ||
if len(task_handler.expected_results) > 0: | ||
task_handler.send_result(task_handler.expected_results[0], Result(0.0).serialize()) | ||
logger.info("No values") | ||
return Output() | ||
|
||
if isinstance(payload.values[0], str): | ||
# Aggregation task | ||
results = [Result.deserialize(task_handler.data_dependencies[r]).value for r in payload.values] | ||
task_handler.send_result(task_handler.expected_results[0], Result(aggregate(results)).serialize()) | ||
logger.info(f"Aggregated {len(results)} values") | ||
return Output() | ||
|
||
if len(payload.values) <= 1 or len(payload.values) <= payload.subtask_threshold: | ||
# Compute | ||
task_handler.send_result(task_handler.expected_results[0], Result(aggregate(payload.values)).serialize()) | ||
logger.info(f"Computed {len(payload.values)} values") | ||
return Output() | ||
|
||
# Subtasking | ||
pivot = len(payload.values) // 2 | ||
lower = payload.values[:pivot] | ||
upper = payload.values[pivot:] | ||
subtasks = [] | ||
for vals in [lower, upper]: | ||
new_payload = Payload(values=vals, subtask_threshold=payload.subtask_threshold).serialize() | ||
subtasks.append(TaskDefinition(payload=new_payload, expected_outputs=[task_handler.request_output_id()])) | ||
aggregate_dependencies = [s.expected_outputs[0] for s in subtasks] | ||
subtasks.append(TaskDefinition(Payload(values=aggregate_dependencies).serialize(), expected_outputs=task_handler.expected_results, data_dependencies=aggregate_dependencies)) | ||
if len(subtasks) > 0: | ||
submitted, errors = task_handler.create_tasks(subtasks) | ||
if len(errors) > 0: | ||
message = f"Errors while submitting subtasks : {', '.join(errors)}" | ||
logger.error(message) | ||
return Output(message) | ||
logger.info(f"Submitted {len(submitted)} subtasks") | ||
return Output() | ||
|
||
|
||
def aggregate(values: List[float]) -> float: | ||
return sum(values) | ||
|
||
|
||
def main(): | ||
worker_scheme = "unix://" if os.getenv("ComputePlane__WorkerChannel__SocketType", "unixdomainsocket") == "unixdomainsocket" else "http://" | ||
agent_scheme = "unix://" if os.getenv("ComputePlane__AgentChannel__SocketType", "unixdomainsocket") == "unixdomainsocket" else "http://" | ||
worker_endpoint = worker_scheme+os.getenv("ComputePlane__WorkerChannel__Address", "/cache/armonik_worker.sock") | ||
agent_endpoint = agent_scheme+os.getenv("ComputePlane__AgentChannel__Address", "/cache/armonik_agent.sock") | ||
logger.info("Worker Started") | ||
with grpc.insecure_channel(agent_endpoint) as agent_channel: | ||
worker = ArmoniKWorker(agent_channel, processor, logger=logger) | ||
logger.info("Worker Connected") | ||
worker.start(worker_endpoint) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,7 @@ classifiers = [ | |
dependencies = [ | ||
"grpcio", | ||
"grpcio-tools", | ||
"seqlog" | ||
] | ||
|
||
[project.urls] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
from .submitter import ArmoniKSubmitter | ||
from .tasks import ArmoniKTasks | ||
from .tasks import ArmoniKTasks |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .helpers import datetime_to_timestamp, timestamp_to_datetime, duration_to_timedelta, timedelta_to_duration | ||
from .objects import Task, TaskDefinition, TaskOptions, Output | ||
from .helpers import datetime_to_timestamp, timestamp_to_datetime, duration_to_timedelta, timedelta_to_duration, get_task_filter | ||
from .objects import Task, TaskDefinition, TaskOptions, Output | ||
from .enumwrapper import HealthCheckStatus, TaskStatus |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import enum | ||
|
||
from ..protogen.common.task_status_pb2 import * | ||
from ..protogen.common.worker_common_pb2 import HealthCheckReply | ||
|
||
# This file is necessary because the grpc values don't have the proper type | ||
|
||
|
||
class HealthCheckStatus(enum.Enum): | ||
NOT_SERVING = HealthCheckReply.NOT_SERVING | ||
SERVING = HealthCheckReply.SERVING | ||
UNKNOWN = HealthCheckReply.UNKNOWN | ||
|
||
|
||
class TaskStatus(enum.Enum): | ||
CANCELLED = TASK_STATUS_CANCELLED | ||
CANCELLING = TASK_STATUS_CANCELLING | ||
COMPLETED = TASK_STATUS_COMPLETED | ||
CREATING = TASK_STATUS_CREATING | ||
DISPATCHED = TASK_STATUS_DISPATCHED | ||
ERROR = TASK_STATUS_ERROR | ||
PROCESSED = TASK_STATUS_PROCESSED | ||
PROCESSING = TASK_STATUS_PROCESSING | ||
SUBMITTED = TASK_STATUS_SUBMITTED | ||
TIMEOUT = TASK_STATUS_TIMEOUT | ||
UNSPECIFIED = TASK_STATUS_UNSPECIFIED |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .worker import ArmoniKWorker | ||
from .taskhandler import TaskHandler | ||
from .taskhandler import TaskHandler | ||
from .seqlogger import get_worker_logger |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from logging import Logger, getLogger, Formatter | ||
from seqlog import ConsoleStructuredLogHandler | ||
from typing import Dict | ||
|
||
_worker_loggers: Dict[str, Logger] = {} | ||
|
||
|
||
def get_worker_logger(name: str, level: int) -> Logger: | ||
if name in _worker_loggers: | ||
return _worker_loggers[name] | ||
logger = getLogger(name) | ||
logger.handlers.clear() | ||
handler = ConsoleStructuredLogHandler() | ||
handler.setFormatter(Formatter(style="{")) | ||
logger.addHandler(handler) | ||
logger.setLevel(level) | ||
_worker_loggers[name] = logger | ||
return _worker_loggers[name] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters