Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async Agent #1712

Merged
merged 12 commits into from
Aug 12, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions flytekit/clis/sdk_in_container/serve.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from concurrent import futures

import click
import grpc
from flyteidl.service.agent_pb2_grpc import add_AsyncAgentServiceServicer_to_server
from grpc import aio

from flytekit.extend.backend.agent_service import AgentService
from flytekit.extend.backend.agent_service import AsyncAgentService

_serve_help = """Start a grpc server for the agent service."""

Expand Down Expand Up @@ -37,10 +37,16 @@ def serve(_: click.Context, port, worker, timeout):
"""
Start a grpc server for the agent service.
"""
import asyncio

asyncio.run(_start_grpc_server(port, worker, timeout))


async def _start_grpc_server(port: int, worker: int, timeout: int):
click.secho("Starting the agent service...", fg="blue")
server = grpc.server(futures.ThreadPoolExecutor(max_workers=worker))
add_AsyncAgentServiceServicer_to_server(AgentService(), server)
server = aio.server(futures.ThreadPoolExecutor(max_workers=worker))
add_AsyncAgentServiceServicer_to_server(AsyncAgentService(), server)

server.add_insecure_port(f"[::]:{port}")
server.start()
server.wait_for_termination(timeout=timeout)
await server.start()
await server.wait_for_termination(timeout)
18 changes: 11 additions & 7 deletions flytekit/extend/backend/agent_service.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio

import grpc
from flyteidl.admin.agent_pb2 import (
PERMANENT_FAILURE,
Expand All @@ -17,37 +19,39 @@
from flytekit.models.task import TaskTemplate


class AgentService(AsyncAgentServiceServicer):
def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerContext) -> CreateTaskResponse:
class AsyncAgentService(AsyncAgentServiceServicer):
async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerContext) -> CreateTaskResponse:
try:
tmp = TaskTemplate.from_flyte_idl(request.template)
inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None
agent = AgentRegistry.get_agent(context, tmp.type)
if agent is None:
return CreateTaskResponse()
return agent.create(context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp)
return await asyncio.to_thread(
agent.create, context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp
)
except Exception as e:
logger.error(f"failed to create task with error {e}")
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"failed to create task with error {e}")

def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) -> GetTaskResponse:
async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) -> GetTaskResponse:
try:
agent = AgentRegistry.get_agent(context, request.task_type)
if agent is None:
return GetTaskResponse(resource=Resource(state=PERMANENT_FAILURE))
return agent.get(context=context, resource_meta=request.resource_meta)
return await asyncio.to_thread(agent.get, context=context, resource_meta=request.resource_meta)
except Exception as e:
logger.error(f"failed to get task with error {e}")
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"failed to get task with error {e}")

def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse:
async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse:
try:
agent = AgentRegistry.get_agent(context, request.task_type)
if agent is None:
return DeleteTaskResponse()
return agent.delete(context=context, resource_meta=request.resource_meta)
return asyncio.to_thread(agent.delete, context=context, resource_meta=request.resource_meta)
except Exception as e:
logger.error(f"failed to delete task with error {e}")
context.set_code(grpc.StatusCode.INTERNAL)
Expand Down
24 changes: 12 additions & 12 deletions tests/flytekit/unit/extend/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import json
import typing
from dataclasses import asdict, dataclass
Expand All @@ -21,14 +22,15 @@

import flytekit.models.interface as interface_models
from flytekit import PythonFunctionTask
from flytekit.extend.backend.agent_service import AgentService
from flytekit.extend.backend.agent_service import AsyncAgentService
from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, AsyncAgentExecutorMixin, is_terminal_state
from flytekit.models import literals, task, types
from flytekit.models.core.identifier import Identifier, ResourceType
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate

dummy_id = "dummy_id"
loop = asyncio.get_event_loop()


@dataclass
Expand Down Expand Up @@ -117,24 +119,22 @@ def __init__(self, **kwargs):


def test_agent_server():
service = AgentService()
service = AsyncAgentService()
ctx = MagicMock(spec=grpc.ServicerContext)
request = CreateTaskRequest(
inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=dummy_template.to_flyte_idl()
)

metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8")
assert service.CreateTask(request, ctx).resource_meta == metadata_bytes
assert (
service.GetTask(GetTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx).resource.state
== SUCCEEDED
)
assert (
service.DeleteTask(DeleteTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx)
== DeleteTaskResponse()
)
res = loop.run_until_complete(service.CreateTask(request, ctx))
assert res.resource_meta == metadata_bytes

res = loop.run_until_complete(service.GetTask(GetTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx))
assert res.resource.state == SUCCEEDED

loop.run_until_complete(service.DeleteTask(DeleteTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx))

res = service.GetTask(GetTaskRequest(task_type="fake", resource_meta=metadata_bytes), ctx)
res = loop.run_until_complete(service.GetTask(GetTaskRequest(task_type="fake", resource_meta=metadata_bytes), ctx))
assert res.resource.state == PERMANENT_FAILURE


Expand Down