From 98d722fec491ecd68ba3ea92dfae9f676d32f677 Mon Sep 17 00:00:00 2001 From: "Ethan Brown (Domino)" <111539728+ddl-ebrown@users.noreply.github.com> Date: Thu, 26 Sep 2024 10:49:05 -0700 Subject: [PATCH] Propagate custom_info Dict through agent Resource (#2426) * Propagate custom_info Dict through agent Resource - The agent defines a Resource return type with values: * outputs * message * log_links * phase These are all a part of the underlying protobuf contract defined in flyteidl. However, the message field custom_info from the protobuf is not here google.protobuf.Struct custom_info https://github.com/flyteorg/flyte/blob/519080b6e4e53fc0e216b5715ad9b5b5270f35c0/flyteidl/protos/flyteidl/admin/agent.proto#L140 This field was added in https://github.com/flyteorg/flyte/pull/4874 but never made it into the corresponding flytekit PR https://github.com/flyteorg/flytekit/pull/2146 - It's useful for agents to return additional metadata about the job, and it looks like custom_info is the intended location - Make a minor refactor to how the agent responds to requests that return Resource by implementing to_flyte_idl / from_flyte_idl directly Signed-off-by: ddl-ebrown Signed-off-by: ddl-rliu * Fix test Signed-off-by: ddl-rliu --------- Signed-off-by: ddl-ebrown Signed-off-by: ddl-rliu Co-authored-by: ddl-rliu --- flytekit/extend/backend/agent_service.py | 27 +------ flytekit/extend/backend/base_agent.py | 36 ++++++++- tests/flytekit/unit/extend/test_agent.py | 96 +++++++++++++++++++++--- 3 files changed, 124 insertions(+), 35 deletions(-) diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index a92cef8e36..9b444d101e 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -16,7 +16,6 @@ GetTaskResponse, ListAgentsRequest, ListAgentsResponse, - Resource, ) from flyteidl.service.agent_pb2_grpc import ( AgentMetadataServiceServicer, @@ -25,8 +24,7 @@ ) from prometheus_client import Counter, Summary -from flytekit import FlyteContext, logger -from flytekit.core.type_engine import TypeEngine +from flytekit import logger from flytekit.exceptions.system import FlyteAgentNotFound from flytekit.extend.backend.base_agent import AgentRegistry, SyncAgentBase, mirror_async_methods from flytekit.models.literals import LiteralMap @@ -136,16 +134,7 @@ async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) logger.info(f"{agent.name} start checking the status of the job") res = await mirror_async_methods(agent.get, resource_meta=agent.metadata_type.decode(request.resource_meta)) - if res.outputs is None: - outputs = None - elif isinstance(res.outputs, LiteralMap): - outputs = res.outputs.to_flyte_idl() - else: - ctx = FlyteContext.current_context() - outputs = TypeEngine.dict_to_literal_map_pb(ctx, res.outputs) - return GetTaskResponse( - resource=Resource(phase=res.phase, log_links=res.log_links, message=res.message, outputs=outputs) - ) + return GetTaskResponse(resource=res.to_flyte_idl()) @record_agent_metrics async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse: @@ -178,17 +167,7 @@ async def ExecuteTaskSync( agent.do, task_template=template, inputs=literal_map, output_prefix=output_prefix ) - if res.outputs is None: - outputs = None - elif isinstance(res.outputs, LiteralMap): - outputs = res.outputs.to_flyte_idl() - else: - ctx = FlyteContext.current_context() - outputs = TypeEngine.dict_to_literal_map_pb(ctx, res.outputs) - - header = ExecuteTaskSyncResponseHeader( - resource=Resource(phase=res.phase, log_links=res.log_links, message=res.message, outputs=outputs) - ) + header = ExecuteTaskSyncResponseHeader(resource=res.to_flyte_idl()) yield ExecuteTaskSyncResponse(header=header) request_success_count.labels(task_type=task_type, operation=do_operation).inc() except Exception as e: diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 9f155da321..f8264edc92 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -12,9 +12,12 @@ from typing import Any, Dict, List, Optional, Union from flyteidl.admin.agent_pb2 import Agent +from flyteidl.admin.agent_pb2 import Resource as _Resource from flyteidl.admin.agent_pb2 import TaskCategory as _TaskCategory from flyteidl.core import literals_pb2 from flyteidl.core.execution_pb2 import TaskExecution, TaskLog +from google.protobuf import json_format +from google.protobuf.struct_pb2 import Struct from rich.logging import RichHandler from rich.progress import Progress @@ -28,6 +31,7 @@ from flytekit.exceptions.user import FlyteUserException from flytekit.extend.backend.utils import is_terminal_phase, mirror_async_methods, render_task_template from flytekit.loggers import set_flytekit_log_properties +from flytekit.models import common from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskExecutionMetadata, TaskTemplate @@ -76,7 +80,7 @@ def decode(cls, data: bytes) -> "ResourceMeta": @dataclass -class Resource: +class Resource(common.FlyteIdlEntity): """ This is the output resource of the job. @@ -91,6 +95,36 @@ class Resource: message: Optional[str] = None log_links: Optional[List[TaskLog]] = None outputs: Optional[Union[LiteralMap, typing.Dict[str, Any]]] = None + custom_info: Optional[typing.Dict[str, Any]] = None + + def to_flyte_idl(self) -> _Resource: + if self.outputs is None: + outputs = None + elif isinstance(self.outputs, LiteralMap): + outputs = self.outputs.to_flyte_idl() + else: + ctx = FlyteContext.current_context() + outputs = TypeEngine.dict_to_literal_map_pb(ctx, self.outputs) + + return _Resource( + phase=self.phase, + message=self.message, + log_links=self.log_links, + outputs=outputs, + custom_info=(json_format.Parse(json.dumps(self.custom_info), Struct()) if self.custom_info else None), + ) + + @classmethod + def from_flyte_idl(cls, pb2_object: _Resource): + return cls( + phase=pb2_object.phase, + message=pb2_object.message, + log_links=pb2_object.log_links, + outputs=(LiteralMap.from_flyte_idl(pb2_object.outputs) if pb2_object.outputs else None), + custom_info=( + json_format.MessageToDict(pb2_object.custom_info) if pb2_object.HasField("custom_info") else None + ), + ) class AgentBase(ABC): diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index f3f0658286..946bf3a778 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -22,11 +22,20 @@ from flytekit import PythonFunctionTask, task from flytekit.clis.sdk_in_container.serve import print_agents_metadata -from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings +from flytekit.configuration import ( + FastSerializationSettings, + Image, + ImageConfig, + SerializationSettings, +) from flytekit.core.base_task import PythonTask, kwtypes from flytekit.core.interface import Interface from flytekit.exceptions.system import FlyteAgentNotFound -from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService, SyncAgentService +from flytekit.extend.backend.agent_service import ( + AgentMetadataService, + AsyncAgentService, + SyncAgentService, +) from flytekit.extend.backend.base_agent import ( AgentRegistry, AsyncAgentBase, @@ -71,7 +80,11 @@ def create(self, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap return DummyMetadata(job_id=dummy_id) def get(self, resource_meta: DummyMetadata, **kwargs) -> Resource: - return Resource(phase=TaskExecution.SUCCEEDED, log_links=[TaskLog(name="console", uri="localhost:3000")]) + return Resource( + phase=TaskExecution.SUCCEEDED, + log_links=[TaskLog(name="console", uri="localhost:3000")], + custom_info={"custom": "info", "num": 1}, + ) def delete(self, resource_meta: DummyMetadata, **kwargs): ... @@ -96,7 +109,11 @@ async def create( return DummyMetadata(job_id=dummy_id, output_path=output_path, task_name=task_name) async def get(self, resource_meta: DummyMetadata, **kwargs) -> Resource: - return Resource(phase=TaskExecution.SUCCEEDED, log_links=[TaskLog(name="console", uri="localhost:3000")]) + return Resource( + phase=TaskExecution.SUCCEEDED, + log_links=[TaskLog(name="console", uri="localhost:3000")], + custom_info={"custom": "info", "num": 1}, + ) async def delete(self, resource_meta: DummyMetadata, **kwargs): ... @@ -108,7 +125,12 @@ class MockOpenAIAgent(SyncAgentBase): def __init__(self): super().__init__(task_type_name="openai") - def do(self, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, **kwargs) -> Resource: + def do( + self, + task_template: TaskTemplate, + inputs: typing.Optional[LiteralMap] = None, + **kwargs, + ) -> Resource: assert inputs.literals["a"].scalar.primitive.integer == 1 return Resource(phase=TaskExecution.SUCCEEDED, outputs={"o0": 1}) @@ -174,6 +196,8 @@ def test_dummy_agent(): assert resource.phase == TaskExecution.SUCCEEDED assert resource.log_links[0].name == "console" assert resource.log_links[0].uri == "localhost:3000" + assert resource.custom_info["custom"] == "info" + assert resource.custom_info["num"] == 1 assert agent.delete(metadata) is None class DummyTask(AsyncAgentExecutorMixin, PythonFunctionTask): @@ -189,7 +213,9 @@ def __init__(self, **kwargs): @pytest.mark.parametrize( - "agent,consume_metadata", [(DummyAgent(), False), (AsyncDummyAgent(), True)], ids=["sync", "async"] + "agent,consume_metadata", + [(DummyAgent(), False), (AsyncDummyAgent(), True)], + ids=["sync", "async"], ) @pytest.mark.asyncio async def test_async_agent_service(agent, consume_metadata): @@ -222,7 +248,10 @@ async def test_async_agent_service(agent, consume_metadata): assert res.resource_meta == metadata_bytes res = await service.GetTask(GetTaskRequest(task_category=task_category, resource_meta=metadata_bytes), ctx) assert res.resource.phase == TaskExecution.SUCCEEDED - res = await service.DeleteTask(DeleteTaskRequest(task_category=task_category, resource_meta=metadata_bytes), ctx) + res = await service.DeleteTask( + DeleteTaskRequest(task_category=task_category, resource_meta=metadata_bytes), + ctx, + ) assert res == DeleteTaskResponse() agent_metadata = AgentRegistry.get_agent_metadata(agent.name) @@ -269,7 +298,9 @@ def test_openai_agent(): class OpenAITask(SyncAgentExecutorMixin, PythonTask): def __init__(self, **kwargs): super().__init__( - task_type="openai", interface=Interface(inputs=kwtypes(a=int), outputs=kwtypes(o0=int)), **kwargs + task_type="openai", + interface=Interface(inputs=kwtypes(a=int), outputs=kwtypes(o0=int)), + **kwargs, ) t = OpenAITask(task_config={}, name="openai task") @@ -393,9 +424,54 @@ def test_render_task_template(): @pytest.fixture def sample_agents(): async_agent = Agent( - name="Sensor", is_sync=False, supported_task_categories=[TaskCategory(name="sensor", version=0)] + name="Sensor", + is_sync=False, + supported_task_categories=[TaskCategory(name="sensor", version=0)], ) sync_agent = Agent( - name="ChatGPT Agent", is_sync=True, supported_task_categories=[TaskCategory(name="chatgpt", version=0)] + name="ChatGPT Agent", + is_sync=True, + supported_task_categories=[TaskCategory(name="chatgpt", version=0)], ) return [async_agent, sync_agent] + + +def test_resource_type(): + o = Resource( + phase=TaskExecution.SUCCEEDED, + ) + v = o.to_flyte_idl() + assert v + assert v.phase == TaskExecution.SUCCEEDED + assert len(v.log_links) == 0 + assert v.message == "" + assert len(v.outputs.literals) == 0 + assert len(v.custom_info) == 0 + + o2 = Resource.from_flyte_idl(v) + assert o2 + + o = Resource( + phase=TaskExecution.SUCCEEDED, + log_links=[TaskLog(name="console", uri="localhost:3000")], + message="foo", + outputs={"o0": 1}, + custom_info={"custom": "info", "num": 1}, + ) + v = o.to_flyte_idl() + assert v + assert v.phase == TaskExecution.SUCCEEDED + assert v.log_links[0].name == "console" + assert v.log_links[0].uri == "localhost:3000" + assert v.message == "foo" + assert v.outputs.literals["o0"].scalar.primitive.integer == 1 + assert v.custom_info["custom"] == "info" + assert v.custom_info["num"] == 1 + + o2 = Resource.from_flyte_idl(v) + assert o2.phase == o.phase + assert list(o2.log_links) == list(o.log_links) + assert o2.message == o.message + # round-tripping creates a literal map out of outputs + assert o2.outputs.literals["o0"].scalar.primitive.integer == 1 + assert o2.custom_info == o.custom_info