diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index 3e1527c5c5..eb2838ca41 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -30,7 +30,7 @@ from flytekit.exceptions.system import FlyteAgentNotFound from flytekit.extend.backend.base_agent import AgentRegistry, SyncAgentBase, mirror_async_methods from flytekit.models.literals import LiteralMap -from flytekit.models.task import TaskTemplate +from flytekit.models.task import TaskExecutionMetadata, TaskTemplate metric_prefix = "flyte_agent_" create_operation = "create" @@ -115,6 +115,7 @@ async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerCon template = TaskTemplate.from_flyte_idl(request.template) inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None agent = AgentRegistry.get_agent(template.type, template.task_type_version) + task_execution_metadata = TaskExecutionMetadata.from_flyte_idl(request.task_execution_metadata) logger.info(f"{agent.name} start creating the job") resource_mata = await mirror_async_methods( @@ -122,6 +123,7 @@ async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerCon task_template=template, inputs=inputs, output_prefix=request.output_prefix, + task_execution_metadata=task_execution_metadata, ) return CreateTaskResponse(resource_meta=resource_mata.encode()) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 3c1a149abc..ac942a3642 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -26,7 +26,7 @@ from flytekit.exceptions.user import FlyteUserException from flytekit.extend.backend.utils import is_terminal_phase, mirror_async_methods, render_task_template from flytekit.models.literals import LiteralMap -from flytekit.models.task import TaskTemplate +from flytekit.models.task import TaskExecutionMetadata, TaskTemplate class TaskCategory: @@ -146,7 +146,12 @@ def metadata_type(self) -> ResourceMeta: @abstractmethod def create( - self, task_template: TaskTemplate, inputs: Optional[LiteralMap], output_prefix: Optional[str], **kwargs + self, + task_template: TaskTemplate, + inputs: Optional[LiteralMap], + output_prefix: Optional[str], + task_execution_metadata: Optional[TaskExecutionMetadata], + **kwargs, ) -> ResourceMeta: """ Return a resource meta that can be used to get the status of the task. diff --git a/flytekit/extend/backend/utils.py b/flytekit/extend/backend/utils.py index 5199536b5d..dcea3e6b34 100644 --- a/flytekit/extend/backend/utils.py +++ b/flytekit/extend/backend/utils.py @@ -1,4 +1,5 @@ import asyncio +import functools import inspect from typing import Callable, Coroutine @@ -11,8 +12,7 @@ def mirror_async_methods(func: Callable, **kwargs) -> Coroutine: if inspect.iscoroutinefunction(func): return func(**kwargs) - args = [v for _, v in kwargs.items()] - return asyncio.get_running_loop().run_in_executor(None, func, *args) + return asyncio.get_running_loop().run_in_executor(None, functools.partial(func, **kwargs)) def convert_to_flyte_phase(state: str) -> TaskExecution.Phase: diff --git a/flytekit/models/task.py b/flytekit/models/task.py index b6e8222fb9..198adf2859 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -1,6 +1,7 @@ import json as _json import typing +from flyteidl.admin import agent_pb2 as _admin_agent from flyteidl.admin import task_pb2 as _admin_task from flyteidl.core import compiler_pb2 as _compiler from flyteidl.core import literals_pb2 as _literals_pb2 @@ -518,6 +519,94 @@ def from_flyte_idl(cls, pb2_object): ) +class TaskExecutionMetadata(_common.FlyteIdlEntity): + def __init__( + self, + task_execution_id, + namespace, + labels, + annotations, + k8s_service_account, + environment_variables, + ): + """ + Runtime task execution metadata. + + :param flytekit.models.core.identifier.TaskExecutionIdentifier task_execution_id: This is generated by the system and uniquely identifies + this execution of the task. + :param Text namespace: This is the namespace the task is executing in. + :param dict[str, str] labels: Labels to use for the execution of this task. + :param dict[str, str] annotations: Annotations to use for the execution of this task. + :param Text k8s_service_account: Service account to use for execution of this task. + :param dict[str, str] environment_variables: Environment variables for this task. + """ + self._task_execution_id = task_execution_id + self._namespace = namespace + self._labels = labels + self._annotations = annotations + self._k8s_service_account = k8s_service_account + self._environment_variables = environment_variables + + @property + def task_execution_id(self): + return self._task_execution_id + + @property + def namespace(self): + return self._namespace + + @property + def labels(self): + return self._labels + + @property + def annotations(self): + return self._annotations + + @property + def k8s_service_account(self): + return self._k8s_service_account + + @property + def environment_variables(self): + return self._environment_variables + + def to_flyte_idl(self): + """ + :rtype: flyteidl.admin.agent_pb2.TaskExecutionMetadata + """ + task_execution_metadata = _admin_agent.TaskExecutionMetadata( + task_execution_id=self.task_execution_id.to_flyte_idl(), + namespace=self.namespace, + labels={k: v for k, v in self.labels.items()} if self.labels is not None else None, + annotations={k: v for k, v in self.annotations.items()} if self.annotations is not None else None, + k8s_service_account=self.k8s_service_account, + environment_variables={k: v for k, v in self.environment_variables.items()} + if self.labels is not None + else None, + ) + return task_execution_metadata + + @classmethod + def from_flyte_idl(cls, pb2_object): + """ + :param flyteidl.admin.agent_pb2.TaskExecutionMetadata pb2_object: + :rtype: TaskExecutionMetadata + """ + return cls( + task_execution_id=_identifier.TaskExecutionIdentifier.from_flyte_idl(pb2_object.task_execution_id), + namespace=pb2_object.namespace, + labels={k: v for k, v in pb2_object.labels.items()} if pb2_object.labels is not None else None, + annotations={k: v for k, v in pb2_object.annotations.items()} + if pb2_object.annotations is not None + else None, + k8s_service_account=pb2_object.k8s_service_account, + environment_variables={k: v for k, v in pb2_object.environment_variables.items()} + if pb2_object.environment_variables is not None + else None, + ) + + class TaskSpec(_common.FlyteIdlEntity): def __init__(self, template: TaskTemplate, docs: typing.Optional[Documentation] = None): """ diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index 1bb4976cbd..2bf23abb25 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -17,6 +17,7 @@ TaskCategory, ) from flyteidl.core.execution_pb2 import TaskExecution, TaskLog +from flyteidl.core.identifier_pb2 import ResourceType from flytekit import PythonFunctionTask, task from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings @@ -37,8 +38,14 @@ ) from flytekit.extend.backend.utils import convert_to_flyte_phase, get_agent_secret from flytekit.models import literals +from flytekit.models.core.identifier import ( + Identifier, + NodeExecutionIdentifier, + TaskExecutionIdentifier, + WorkflowExecutionIdentifier, +) from flytekit.models.literals import LiteralMap -from flytekit.models.task import TaskTemplate +from flytekit.models.task import TaskExecutionMetadata, TaskTemplate from flytekit.tools.translator import get_serializable dummy_id = "dummy_id" @@ -48,6 +55,7 @@ class DummyMetadata(ResourceMeta): job_id: str output_path: typing.Optional[str] = None + task_name: typing.Optional[str] = None class DummyAgent(AsyncAgentBase): @@ -77,10 +85,12 @@ async def create( task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, output_prefix: typing.Optional[str] = None, + task_execution_metadata: typing.Optional[TaskExecutionMetadata] = None, **kwargs, ) -> DummyMetadata: output_path = f"{output_prefix}/{dummy_id}" if output_prefix else None - return DummyMetadata(job_id=dummy_id, output_path=output_path) + task_name = task_execution_metadata.task_execution_id.task_id.name if task_execution_metadata else "default" + return DummyMetadata(job_id=dummy_id, output_path=output_path, task_name=task_name) async def get(self, resource_meta: DummyMetadata, **kwargs) -> Resource: return Resource(phase=TaskExecution.SUCCEEDED, log_links=[TaskLog(name="console", uri="localhost:3000")]) @@ -136,6 +146,19 @@ def simple_task(i: int): }, ) +task_execution_metadata = TaskExecutionMetadata( + task_execution_id=TaskExecutionIdentifier( + task_id=Identifier(ResourceType.TASK, "project", "domain", "name", "version"), + node_execution_id=NodeExecutionIdentifier("node_id", WorkflowExecutionIdentifier("project", "domain", "name")), + retry_attempt=1, + ), + namespace="namespace", + labels={"label_key": "label_val"}, + annotations={"annotation_key": "annotation_val"}, + k8s_service_account="k8s service account", + environment_variables={"env_var_key": "env_var_val"}, +) + def test_dummy_agent(): AgentRegistry.register(DummyAgent(), override=True) @@ -161,20 +184,35 @@ def __init__(self, **kwargs): t.execute() -@pytest.mark.parametrize("agent", [DummyAgent(), AsyncDummyAgent()], ids=["sync", "async"]) +@pytest.mark.parametrize( + "agent,consume_metadata", [(DummyAgent(), False), (AsyncDummyAgent(), True)], ids=["sync", "async"] +) @pytest.mark.asyncio -async def test_async_agent_service(agent): +async def test_async_agent_service(agent, consume_metadata): AgentRegistry.register(agent, override=True) service = AsyncAgentService() ctx = MagicMock(spec=grpc.ServicerContext) inputs_proto = task_inputs.to_flyte_idl() output_prefix = "/tmp" - metadata_bytes = DummyMetadata(job_id=dummy_id, output_path=f"{output_prefix}/{dummy_id}").encode() + metadata_bytes = ( + DummyMetadata( + job_id=dummy_id, + output_path=f"{output_prefix}/{dummy_id}", + task_name=task_execution_metadata.task_execution_id.task_id.name, + ).encode() + if consume_metadata + else DummyMetadata(job_id=dummy_id).encode() + ) tmp = get_task_template(agent.task_category.name).to_flyte_idl() task_category = TaskCategory(name=agent.task_category.name, version=0) - req = CreateTaskRequest(inputs=inputs_proto, output_prefix=output_prefix, template=tmp) + req = CreateTaskRequest( + inputs=inputs_proto, + template=tmp, + output_prefix=output_prefix, + task_execution_metadata=task_execution_metadata.to_flyte_idl(), + ) res = await service.CreateTask(req, ctx) assert res.resource_meta == metadata_bytes