Commit
Added armonik python wrapper
dbrasseur-aneo committed Mar 21, 2023
1 parent 380733d commit 44a423d
Showing 20 changed files with 831 additions and 24 deletions.
4 changes: 4 additions & 0 deletions .editorconfig
@@ -7,3 +7,7 @@ indent_size = 2
end_of_line = lf
insert_final_newline = true
trim_trailing_whitespace = true

[packages/python/**]
indent_size = 4
max_line_length = off
1 change: 1 addition & 0 deletions examples/python/.gitignore
@@ -0,0 +1 @@
*.whl
14 changes: 14 additions & 0 deletions examples/python/Dockerfile
@@ -0,0 +1,14 @@
FROM python:3.7-slim AS builder
WORKDIR /app
RUN python -m venv .venv && .venv/bin/pip install --no-cache-dir -U pip setuptools
COPY *.whl ./
RUN ( .venv/bin/pip install --no-cache-dir *.whl || .venv/bin/pip install --no-cache-dir armonik ) && find /app/.venv \( -type d -a -name test -o -name tests \) -o \( -type f -a -name '*.pyc' -o -name '*.pyo' \) -exec rm -rf '{}' \+

FROM python:3.7-slim
WORKDIR /app
RUN groupadd --gid 5000 armonikuser && useradd --home-dir /home/armonikuser --create-home --uid 5000 --gid 5000 --shell /bin/sh --skel /dev/null armonikuser && mkdir /cache && chown armonikuser: /cache
USER armonikuser
ENV PATH="/app/.venv/bin:$PATH" PYTHONUNBUFFERED=1
COPY --from=builder /app /app
COPY . .
ENTRYPOINT ["python", "worker.py"]
Empty file added examples/python/client.py
Empty file.
Empty file added examples/python/common.py
Empty file.
Empty file added examples/python/worker.py
Empty file.
5 changes: 3 additions & 2 deletions packages/python/.gitignore
@@ -1,4 +1,5 @@
pkg/
generated/
armonik/protogen
build/
*.egg-info
*.egg-info
**/_version.py
Empty file.
Empty file.
190 changes: 190 additions & 0 deletions packages/python/armonik/client/submitter.py
@@ -0,0 +1,190 @@
import uuid
from typing import Optional, List, Tuple, Dict

from grpc import Channel

from ..common.helpers import get_task_filter
from ..common.objects import Configuration, TaskOptions, TaskStatus, TaskDefinition, Task
from ..protogen.client.submitter_service_pb2_grpc import SubmitterStub
from ..protogen.common.objects_pb2 import Empty, TaskRequest, ResultRequest, DataChunk, InitTaskRequest, \
    TaskRequestHeader
from ..protogen.common.submitter_common_pb2 import CreateSessionRequest, GetTaskStatusRequest, CreateLargeTaskRequest, \
    WaitRequest

"""
rpc GetServiceConfiguration(Empty) returns (Configuration);
rpc CreateSession(CreateSessionRequest) returns (CreateSessionReply);
rpc CancelSession(Session) returns (Empty);
rpc CreateSmallTasks(CreateSmallTaskRequest) returns (CreateTaskReply);
rpc CreateLargeTasks(stream CreateLargeTaskRequest) returns (CreateTaskReply);
rpc ListTasks(TaskFilter) returns (TaskIdList);
rpc ListSessions(SessionFilter) returns (SessionIdList);
rpc CountTasks(TaskFilter) returns (Count);
rpc TryGetResultStream(ResultRequest) returns (stream ResultReply);
rpc TryGetTaskOutput(TaskOutputRequest) returns (Output);
rpc WaitForAvailability(ResultRequest) returns (AvailabilityReply) {
option deprecated = true;
}
rpc WaitForCompletion(WaitRequest) returns (Count);
rpc CancelTasks(TaskFilter) returns (Empty);
rpc GetTaskStatus(GetTaskStatusRequest) returns (GetTaskStatusReply);
rpc GetResultStatus(GetResultStatusRequest) returns (GetResultStatusReply) {
option deprecated = true;
}
"""


class ArmoniKSubmitter:
    def __init__(self, grpc_channel: Channel):
        self._client = SubmitterStub(grpc_channel)

    def get_service_configuration(self) -> Configuration:
        return Configuration(self._client.GetServiceConfiguration(Empty()))

    def create_session(self, default_task_options: TaskOptions, partition_ids: Optional[List[str]] = None) -> str:
        if partition_ids is None:
            partition_ids = []
        request = CreateSessionRequest(default_task_option=default_task_options)
        for partition in partition_ids:
            request.partition_ids.append(partition)
        return self._client.CreateSession(request).session_id

    def submit(self, session_id: str, tasks: List[TaskDefinition], task_options: Optional[TaskOptions] = None) -> Tuple[List[Task], List[str]]:
        task_requests = []

        for t in tasks:
            task_request = TaskRequest()
            task_request.expected_output_keys.extend(t.expected_outputs)
            if t.data_dependencies is not None:
                task_request.data_dependencies.extend(t.data_dependencies)
            task_request.payload = t.payload
            task_requests.append(task_request)

        configuration = self.get_service_configuration()
        create_tasks_reply = self._client.CreateLargeTasks(
            to_request_stream(task_requests, session_id, task_options, configuration.data_chunk_max_size))
        ret = create_tasks_reply.WhichOneof("Response")
        if ret is None or ret == "error":
            raise Exception(f'Issue with server when submitting tasks: {create_tasks_reply.error}')
        elif ret == "creation_status_list":
            tasks_created = []
            tasks_creation_failed = []
            for creation_status in create_tasks_reply.creation_status_list.creation_statuses:
                if creation_status.WhichOneof("Status") == "task_info":
                    tasks_created.append(Task(id=creation_status.task_info.task_id, session_id=session_id,
                                              expected_output_ids=[k for k in
                                                                   creation_status.task_info.expected_output_keys],
                                              data_dependencies=[k for k in
                                                                 creation_status.task_info.data_dependencies]))
                else:
                    tasks_creation_failed.append(creation_status.error)
        else:
            raise Exception("Unknown value")
        return tasks_created, tasks_creation_failed

    def list_tasks(self, session_ids: Optional[List[str]] = None, task_ids: Optional[List[str]] = None,
                   included_statuses: Optional[List[TaskStatus]] = None,
                   excluded_statuses: Optional[List[TaskStatus]] = None) -> List[str]:
        return [t for t in self._client.ListTasks(
            get_task_filter(session_ids, task_ids, included_statuses, excluded_statuses)).task_ids]

    def get_task_status(self, task_ids: List[str]) -> Dict[str, TaskStatus]:
        request = GetTaskStatusRequest()
        request.task_ids.extend(task_ids)
        reply = self._client.GetTaskStatus(request)
        return dict([(s.task_id, TaskStatus(s.status)) for s in reply.id_statuses])

    def wait_for_completion(self,
                            session_ids: Optional[List[str]] = None,
                            task_ids: Optional[List[str]] = None,
                            included_statuses: Optional[List[TaskStatus]] = None,
                            excluded_statuses: Optional[List[TaskStatus]] = None,
                            stop_on_first_task_error: bool = False,
                            stop_on_first_task_cancellation: bool = False) -> Dict[TaskStatus, int]:
        return dict([(TaskStatus(sc.status), sc.count) for sc in self._client.WaitForCompletion(
            WaitRequest(filter=get_task_filter(session_ids, task_ids, included_statuses, excluded_statuses),
                        stop_on_first_task_error=stop_on_first_task_error,
                        stop_on_first_task_cancellation=stop_on_first_task_cancellation)).values])

    def get_result(self, session_id: str, result_id: str) -> bytes:
        result_request = ResultRequest(
            result_id=result_id,
            session=session_id
        )
        streaming_call = self._client.TryGetResultStream(result_request)
        result = bytearray()
        valid = True
        for message in streaming_call:
            ret = message.WhichOneof("type")
            if ret is None:
                raise Exception("Error with server")
            elif ret == "result":
                if message.result.WhichOneof("type") == "data":
                    result += message.result.data
                    valid = False
                elif message.result.WhichOneof("type") == "data_complete":
                    valid = True
            elif ret == "error":
                raise Exception("Task in error")
            else:
                raise Exception("Unknown return type")
        if valid:
            return bytes(result)
        raise Exception("Incomplete Data")

    def request_output_id(self, session_id: str) -> str:
        return f"{session_id}%{uuid.uuid4()}"


def to_request_stream_internal(request, is_last, chunk_max_size):
    req = CreateLargeTaskRequest(
        init_task=InitTaskRequest(
            header=TaskRequestHeader(
                data_dependencies=request.data_dependencies,
                expected_output_keys=request.expected_output_keys
            )
        )
    )
    yield req
    start = 0
    payload_length = len(request.payload)
    if payload_length == 0:
        req = CreateLargeTaskRequest(
            task_payload=DataChunk(data=b'')
        )
        yield req
    while start < payload_length:
        chunk_size = min(chunk_max_size, payload_length - start)
        req = CreateLargeTaskRequest(
            task_payload=DataChunk(data=request.payload[start:start + chunk_size])
        )
        yield req
        start += chunk_size
    req = CreateLargeTaskRequest(
        task_payload=DataChunk(data_complete=True)
    )
    yield req

    if is_last:
        req = CreateLargeTaskRequest(
            init_task=InitTaskRequest(last_task=True)
        )
        yield req


def to_request_stream(requests, s_id, t_options, chunk_max_size):
    req = CreateLargeTaskRequest(
        init_request=CreateLargeTaskRequest.InitRequest(
            session_id=s_id, task_options=t_options))
    yield req
    if len(requests) == 0:
        return
    for r in requests[:-1]:
        for req in to_request_stream_internal(r, False, chunk_max_size):
            yield req
    for req in to_request_stream_internal(requests[-1], True, chunk_max_size):
        yield req
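
The snippet below is a minimal, hypothetical usage sketch of the ArmoniKSubmitter API introduced above. The import paths follow the package layout in this commit, but the gRPC endpoint and the TaskOptions/TaskDefinition constructor arguments are assumptions, since those definitions are not part of this diff.

# Hypothetical usage sketch of ArmoniKSubmitter; the endpoint and the
# TaskOptions/TaskDefinition constructor arguments are assumptions.
from datetime import timedelta

import grpc

from armonik.client.submitter import ArmoniKSubmitter
from armonik.common.objects import TaskDefinition, TaskOptions

with grpc.insecure_channel("localhost:5001") as channel:  # assumed control-plane endpoint
    submitter = ArmoniKSubmitter(channel)

    # Assumed TaskOptions fields; adjust to the actual dataclass definition.
    options = TaskOptions(max_duration=timedelta(minutes=5), priority=1, max_retries=2)
    session_id = submitter.create_session(default_task_options=options)

    # One task producing one result; request_output_id generates the result key.
    output_id = submitter.request_output_id(session_id)
    task = TaskDefinition(payload=b"hello", expected_outputs=[output_id], data_dependencies=[])
    created, failed = submitter.submit(session_id, [task])

    # Block until the session's tasks reach a final status, then fetch the result bytes.
    submitter.wait_for_completion(session_ids=[session_id])
    print(submitter.get_result(session_id, output_id))
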
26 changes: 26 additions & 0 deletions packages/python/armonik/client/tasks.py
@@ -0,0 +1,26 @@
from grpc import Channel

from ..common.objects import Task, TaskStatus, TaskOptions
from ..protogen.client.tasks_service_pb2_grpc import TasksStub
from ..protogen.common.tasks_common_pb2 import GetTaskRequest


class ArmoniKTasks:
    def __init__(self, grpc_channel: Channel):
        self._client = TasksStub(grpc_channel)

    def get_task(self, task_id: str) -> Task:
        task_response = self._client.GetTask(GetTaskRequest(task_id=task_id))
        task = Task()
        raw = task_response.task
        task.id = raw.id
        task.session_id = raw.session_id
        task.owner_pod_id = raw.owner_pod_id
        task.parent_task_ids.extend(raw.parent_task_ids)
        task.data_dependencies.extend(raw.data_dependencies)
        task.expected_output_ids.extend(raw.expected_output_ids)
        task.status = TaskStatus(raw.status)
        task.status_message = raw.status_message
        task.options = TaskOptions.from_message(raw.options)
        task.retry_of_ids.extend(raw.retry_of_ids)
        return task
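
For completeness, a hypothetical call against the ArmoniKTasks client added above; the endpoint and task id are placeholders.

# Hypothetical ArmoniKTasks usage; endpoint and task id are placeholders.
import grpc

from armonik.client.tasks import ArmoniKTasks

with grpc.insecure_channel("localhost:5001") as channel:  # assumed endpoint
    tasks_client = ArmoniKTasks(channel)
    task = tasks_client.get_task("some-task-id")  # placeholder task id
    print(task.status, task.session_id)
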
Empty file.
51 changes: 51 additions & 0 deletions packages/python/armonik/common/helpers.py
@@ -0,0 +1,51 @@
from datetime import timedelta, datetime
from math import floor
from typing import List, Optional

import google.protobuf.duration_pb2 as duration
import google.protobuf.timestamp_pb2 as timestamp

from .objects import TaskStatus
from ..protogen.common.submitter_common_pb2 import TaskFilter


def get_task_filter(session_ids: Optional[List[str]] = None, task_ids: Optional[List[str]] = None,
                    included_statuses: Optional[List[TaskStatus]] = None,
                    excluded_statuses: Optional[List[TaskStatus]] = None) -> TaskFilter:
    if session_ids is not None and task_ids is not None:
        raise ValueError("session_ids and task_ids cannot be defined at the same time")
    if included_statuses is not None and excluded_statuses is not None:
        raise ValueError("included_statuses and excluded_statuses cannot be defined at the same time")
    task_filter = TaskFilter(
        session=TaskFilter.IdsRequest() if session_ids is not None else None,
        task=TaskFilter.IdsRequest() if task_ids is not None else None,
        included=TaskFilter.StatusesRequest() if included_statuses is not None else None,
        excluded=TaskFilter.StatusesRequest() if excluded_statuses is not None else None
    )
    if session_ids is not None:
        task_filter.session.ids.extend(session_ids)
    if task_ids is not None:
        task_filter.task.ids.extend(task_ids)
    if included_statuses is not None:
        task_filter.included.statuses.extend([t.value for t in included_statuses])
    if excluded_statuses is not None:
        task_filter.excluded.statuses.extend([t.value for t in excluded_statuses])
    return task_filter


def datetime_to_timestamp(time_stamp: datetime) -> timestamp.Timestamp:
    secs, fracsec = divmod(time_stamp.timestamp(), 1)
    return timestamp.Timestamp(seconds=int(secs), nanos=floor(fracsec * 1e9))


def timestamp_to_datetime(time_stamp: timestamp.Timestamp) -> datetime:
    return datetime.utcfromtimestamp(time_stamp.seconds + time_stamp.nanos / 1e9)


def duration_to_timedelta(delta: duration.Duration) -> timedelta:
    return timedelta(seconds=delta.seconds, microseconds=delta.nanos // 1000)


def timedelta_to_duration(delta: timedelta) -> duration.Duration:
    secs, remainder = divmod(delta, timedelta(seconds=1))
    return duration.Duration(seconds=secs, nanos=(remainder // timedelta(microseconds=1)) * 1000)
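
A short round-trip sketch of the conversion helpers above; the values are arbitrary examples and the import path follows this commit's package layout.

# Round-trip sketch for the conversion helpers; values are arbitrary examples.
from datetime import datetime, timedelta, timezone

from armonik.common.helpers import (datetime_to_timestamp, timestamp_to_datetime,
                                    timedelta_to_duration, duration_to_timedelta)

ts = datetime_to_timestamp(datetime(2023, 3, 21, tzinfo=timezone.utc))
print(timestamp_to_datetime(ts))   # 2023-03-21 00:00:00

d = timedelta_to_duration(timedelta(seconds=90, milliseconds=250))
print(d.seconds, d.nanos)          # 90 250000000
print(duration_to_timedelta(d))    # 0:01:30.250000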