Skip to content

Commit

Permalink
Async agent delete function for while loop case (flyteorg#1802)
Browse files Browse the repository at this point in the history
Signed-off-by: Future Outlier <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
Co-authored-by: Future Outlier <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
  • Loading branch information
3 people authored Sep 7, 2023
1 parent 5c23325 commit ba571cd
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 60 deletions.
14 changes: 3 additions & 11 deletions flytekit/extend/backend/agent_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@

import grpc
from flyteidl.admin.agent_pb2 import (
PERMANENT_FAILURE,
CreateTaskRequest,
CreateTaskResponse,
DeleteTaskRequest,
DeleteTaskResponse,
GetTaskRequest,
GetTaskResponse,
Resource,
)
from flyteidl.service.agent_pb2_grpc import AsyncAgentServiceServicer

Expand All @@ -24,10 +22,8 @@ async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerCon
try:
tmp = TaskTemplate.from_flyte_idl(request.template)
inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None
agent = AgentRegistry.get_agent(context, tmp.type)
agent = AgentRegistry.get_agent(tmp.type)
logger.info(f"{tmp.type} agent start creating the job")
if agent is None:
return CreateTaskResponse()
if agent.asynchronous:
try:
return await agent.async_create(
Expand All @@ -50,10 +46,8 @@ async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerCon

async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) -> GetTaskResponse:
try:
agent = AgentRegistry.get_agent(context, request.task_type)
agent = AgentRegistry.get_agent(request.task_type)
logger.info(f"{agent.task_type} agent start checking the status of the job")
if agent is None:
return GetTaskResponse(resource=Resource(state=PERMANENT_FAILURE))
if agent.asynchronous:
try:
return await agent.async_get(context=context, resource_meta=request.resource_meta)
Expand All @@ -72,10 +66,8 @@ async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext)

async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse:
try:
agent = AgentRegistry.get_agent(context, request.task_type)
agent = AgentRegistry.get_agent(request.task_type)
logger.info(f"{agent.task_type} agent start deleting the job")
if agent is None:
return DeleteTaskResponse()
if agent.asynchronous:
try:
return await agent.async_delete(context=context, resource_meta=request.resource_meta)
Expand Down
105 changes: 60 additions & 45 deletions flytekit/extend/backend/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,9 @@ def register(agent: AgentBase):
logger.info(f"Registering an agent for task type {agent.task_type}")

# Look up the agent registered for a task type in AgentRegistry._REGISTRY.
#
# This span is a rendered diff hunk: two signatures and two "not found"
# branches appear back to back. One variant takes a grpc context, records
# NOT_FOUND on it and returns None; the other takes only the task type and
# raises ValueError. The tests shown further down this page call
# get_agent("...") with the task type only and match on
# "Unrecognized task type ...".
#
# NOTE(review): the task-type-only signature is still annotated
# typing.Optional[AgentBase] although that variant raises instead of
# returning None -- consider tightening the annotation to AgentBase.
@staticmethod
def get_agent(context: grpc.ServicerContext, task_type: str) -> typing.Optional[AgentBase]:
def get_agent(task_type: str) -> typing.Optional[AgentBase]:
if task_type not in AgentRegistry._REGISTRY:
logger.error(f"Cannot find agent for task type [{task_type}]")
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details(f"Cannot find the agent for task type [{task_type}]")
return None
raise ValueError(f"Unrecognized task type {task_type}")
return AgentRegistry._REGISTRY[task_type]


Expand All @@ -136,9 +133,9 @@ def convert_to_flyte_state(state: str) -> State:
Convert the state from the agent to the state in flyte.
"""
state = state.lower()
if state in ["failed"]:
if state in ["failed", "timedout", "canceled"]:
return RETRYABLE_FAILURE
elif state in ["done", "succeeded"]:
elif state in ["done", "succeeded", "success"]:
return SUCCEEDED
elif state in ["running"]:
return RUNNING
Expand All @@ -158,61 +155,79 @@ class AsyncAgentExecutorMixin:
Task should inherit from this class if the task can be run in the agent.
"""

# Run the task locally through its agent: serialize the task, create the
# remote job, poll it until a terminal state, and return the outputs as a
# LiteralMap.
#
# This span is a rendered diff hunk, so two variants of the method are
# interleaved. One keeps everything in locals (entity, dummy_context, agent)
# and checks `agent is None`; the other stores the entity and agent on the
# instance (self._entity, self._agent) and delegates to the _create and _get
# coroutines.
#
# Raises: Exception when the final resource state is not SUCCEEDED.
def execute(self, **kwargs) -> typing.Any:
from unittest.mock import MagicMock
# Class-level defaults for the state shared between execute, _create, _get
# and signal_handler.
_is_canceled = None
_agent = None
_entity = None

def execute(self, **kwargs) -> typing.Any:
from flytekit.tools.translator import get_serializable

entity = typing.cast(PythonTask, self)
m: OrderedDict = OrderedDict()
dummy_context = MagicMock(spec=grpc.ServicerContext)
cp_entity = get_serializable(m, settings=SerializationSettings(ImageConfig()), entity=entity)
agent = AgentRegistry.get_agent(dummy_context, cp_entity.template.type)
self._entity = typing.cast(PythonTask, self)
task_template = get_serializable(OrderedDict(), SerializationSettings(ImageConfig()), self._entity).template
self._agent = AgentRegistry.get_agent(task_template.type)

if agent is None:
raise Exception("Cannot find the agent for the task")
literals = {}
# Creation and polling are driven by two separate asyncio.run calls.
res = asyncio.run(self._create(task_template, kwargs))
res = asyncio.run(self._get(resource_meta=res.resource_meta))

if res.resource.state != SUCCEEDED:
raise Exception(f"Failed to run the task {self._entity.name}")

return LiteralMap.from_flyte_idl(res.resource.outputs)

# Convert the Python inputs to Flyte literals, ask the agent to create the
# job (awaiting async_create for asynchronous agents, calling create
# otherwise), install a SIGINT handler bound to the new job's resource_meta,
# and return the create response.
#
# This span is a rendered diff hunk: statements that use the locals
# `kwargs`, `entity`, `agent`, `dummy_context` and `cp_entity` are interleaved
# with statements that use `inputs`, `self._entity`, `self._agent`,
# `grpc_ctx` and `task_template`.
#
# NOTE(review): `inputs` defaults to None, but `inputs.items()` below is
# called without a None check.
async def _create(
self, task_template: TaskTemplate, inputs: typing.Dict[str, typing.Any] = None
) -> CreateTaskResponse:
ctx = FlyteContext.current_context()
for k, v in kwargs.items():
literals[k] = TypeEngine.to_literal(ctx, v, type(v), entity.interface.inputs[k].type)
grpc_ctx = _get_grpc_context()

# Convert python inputs to literals
literals = {}
for k, v in inputs.items():
literals[k] = TypeEngine.to_literal(ctx, v, type(v), self._entity.interface.inputs[k].type)
# An empty literal dict is passed on as None rather than an empty LiteralMap.
inputs = LiteralMap(literals) if literals else None
output_prefix = ctx.file_access.get_random_local_directory()
cp_entity = get_serializable(m, settings=SerializationSettings(ImageConfig()), entity=entity)
if agent.asynchronous:
res = asyncio.run(agent.async_create(dummy_context, output_prefix, cp_entity.template, inputs))

if self._agent.asynchronous:
res = await self._agent.async_create(grpc_ctx, output_prefix, task_template, inputs)
else:
res = agent.create(dummy_context, output_prefix, cp_entity.template, inputs)
signal.signal(signal.SIGINT, partial(self.signal_handler, agent, dummy_context, res.resource_meta))
res = self._agent.create(grpc_ctx, output_prefix, task_template, inputs)

# The handler is registered only after the create call has returned,
# because it needs res.resource_meta.
signal.signal(signal.SIGINT, partial(self.signal_handler, res.resource_meta))  # type: ignore
return res

# Poll the agent once per second, under a transient progress display, until
# the job reaches a terminal state; return the last get response.
#
# This span is a rendered diff hunk: lines that use the locals `metadata`,
# `entity`, `agent` and `dummy_context` are interleaved with lines that use
# `resource_meta`, `self._entity`, `self._agent` and `grpc_ctx`.
#
# NOTE(review): `time.sleep(1)` is a blocking call inside a coroutine, so
# nothing else scheduled on the event loop can run during the wait --
# consider `await asyncio.sleep(1)`.
async def _get(self, resource_meta: bytes) -> GetTaskResponse:
state = RUNNING
metadata = res.resource_meta
grpc_ctx = _get_grpc_context()

progress = Progress(transient=True)
task = progress.add_task(f"[cyan]Running Task {entity.name}...", total=None)
task = progress.add_task(f"[cyan]Running Task {self._entity.name}...", total=None)
with progress:
while not is_terminal_state(state):
progress.start_task(task)
time.sleep(1)
if agent.asynchronous:
res = asyncio.run(agent.async_get(dummy_context, metadata))
if self._agent.asynchronous:
res = await self._agent.async_get(grpc_ctx, resource_meta)
# self._is_canceled is set by signal_handler; when present, await it
# and then exit the process with status 1.
if self._is_canceled:
await self._is_canceled
sys.exit(1)
else:
res = agent.get(dummy_context, metadata)
res = self._agent.get(grpc_ctx, resource_meta)
state = res.resource.state
logger.info(f"Task state: {state}")
return res

def signal_handler(self, resource_meta: bytes, signum: int, frame: FrameType) -> typing.Any:
    """
    SIGINT handler that asks the agent to delete the job identified by ``resource_meta``.

    Synchronous agents: the job is deleted immediately and the process exits with status 1.

    Asynchronous agents: if an event loop is running in this thread, the deletion is
    scheduled on it once and stored in ``self._is_canceled`` so the polling coroutine can
    await it and exit. If no event loop is running, the deletion is run to completion
    here and the process exits with status 1.

    :param resource_meta: opaque job metadata returned by the agent's create call.
    :param signum: signal number supplied by the ``signal`` module (unused).
    :param frame: interrupted stack frame supplied by the ``signal`` module (unused).
    """
    grpc_ctx = _get_grpc_context()

    if not self._agent.asynchronous:
        self._agent.delete(grpc_ctx, resource_meta)
        sys.exit(1)

    if self._is_canceled is not None:
        # A deletion has already been scheduled by an earlier interrupt; do not issue a second one.
        return

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so asyncio.create_task would raise and the
        # delete coroutine would never be awaited. Run the deletion to completion and
        # exit, mirroring the synchronous branch.
        asyncio.run(self._agent.async_delete(grpc_ctx, resource_meta))
        sys.exit(1)

    self._is_canceled = asyncio.create_task(self._agent.async_delete(grpc_ctx, resource_meta))

if state != SUCCEEDED:
raise Exception(f"Failed to run the task {entity.name}")

return LiteralMap.from_flyte_idl(res.resource.outputs)
# Build a stand-in grpc.ServicerContext for local execution: a MagicMock
# constrained to the ServicerContext spec. MagicMock is imported inside the
# function rather than at module level.
#
# This span is a rendered diff hunk: the `signal_handler(self, agent, context,
# ...)` definition in the middle is a separate method variant that takes the
# agent and context explicitly and runs `async_delete` via `asyncio.run`; it
# is not part of this helper. The helper itself consists of the `def` line,
# the MagicMock import, and the final two lines.
def _get_grpc_context():
from unittest.mock import MagicMock

def signal_handler(
self,
agent: AgentBase,
context: grpc.ServicerContext,
resource_meta: bytes,
signum: int,
frame: FrameType,
) -> typing.Any:
if agent.asynchronous:
asyncio.run(agent.async_delete(context, resource_meta))
else:
agent.delete(context, resource_meta)
sys.exit(1)
grpc_ctx = MagicMock(spec=grpc.ServicerContext)
return grpc_ctx
2 changes: 1 addition & 1 deletion plugins/flytekit-bigquery/tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self):
mock_instance.cancel_job.return_value = MockJob()

ctx = MagicMock(spec=grpc.ServicerContext)
agent = AgentRegistry.get_agent(ctx, "bigquery_query_job_task")
agent = AgentRegistry.get_agent("bigquery_query_job_task")

task_id = Identifier(
resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version"
Expand Down
29 changes: 26 additions & 3 deletions tests/flytekit/unit/extend/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytest
from flyteidl.admin.agent_pb2 import (
PERMANENT_FAILURE,
RETRYABLE_FAILURE,
RUNNING,
SUCCEEDED,
CreateTaskRequest,
Expand All @@ -23,7 +24,13 @@
import flytekit.models.interface as interface_models
from flytekit import PythonFunctionTask
from flytekit.extend.backend.agent_service import AsyncAgentService
from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, AsyncAgentExecutorMixin, is_terminal_state
from flytekit.extend.backend.base_agent import (
AgentBase,
AgentRegistry,
AsyncAgentExecutorMixin,
convert_to_flyte_state,
is_terminal_state,
)
from flytekit.models import literals, task, types
from flytekit.models.core.identifier import Identifier, ResourceType
from flytekit.models.literals import LiteralMap
Expand Down Expand Up @@ -97,7 +104,7 @@ def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteT

def test_dummy_agent():
ctx = MagicMock(spec=grpc.ServicerContext)
agent = AgentRegistry.get_agent(ctx, "dummy")
agent = AgentRegistry.get_agent("dummy")
metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8")
assert agent.create(ctx, "/tmp", dummy_template, task_inputs).resource_meta == metadata_bytes
assert agent.get(ctx, metadata_bytes).resource.state == SUCCEEDED
Expand All @@ -114,7 +121,7 @@ def __init__(self, **kwargs):
t.execute()

t._task_type = "non-exist-type"
with pytest.raises(Exception, match="Cannot find the agent for the task"):
with pytest.raises(Exception, match="Unrecognized task type non-exist-type"):
t.execute()


Expand Down Expand Up @@ -147,3 +154,19 @@ def test_is_terminal_state():
assert is_terminal_state(PERMANENT_FAILURE)
assert is_terminal_state(PERMANENT_FAILURE)
assert not is_terminal_state(RUNNING)


def test_convert_to_flyte_state():
    """Check that agent state strings map to Flyte states case-insensitively and that unknown states raise."""
    # (agent-reported state, expected Flyte state), checked in order.
    expected_by_state = (
        ("FAILED", RETRYABLE_FAILURE),
        ("TIMEDOUT", RETRYABLE_FAILURE),
        ("CANCELED", RETRYABLE_FAILURE),
        ("DONE", SUCCEEDED),
        ("SUCCEEDED", SUCCEEDED),
        ("SUCCESS", SUCCEEDED),
        ("RUNNING", RUNNING),
    )
    for raw_state, flyte_state in expected_by_state:
        assert convert_to_flyte_state(raw_state) == flyte_state

    # The error message is expected to carry the lower-cased form of the input.
    invalid_state = "INVALID_STATE"
    with pytest.raises(Exception, match=f"Unrecognized state: {invalid_state.lower()}"):
        convert_to_flyte_state(invalid_state)

0 comments on commit ba571cd

Please sign in to comment.