Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Agent Metadata Servicer #2012

Merged
merged 34 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
bfa8649
AgentMetadataService
Nov 30, 2023
6662e73
add supported task type
Dec 4, 2023
69be4c4
list agent
Dec 4, 2023
94e6998
reformatted
Dec 4, 2023
a670fd2
pr example
Dec 4, 2023
95a18df
Merge branch 'master' of https://github.com/Future-Outlier/flytekit i…
Dec 10, 2023
7bc6b1b
agent hasattr is_sync
Dec 10, 2023
68ad44f
remove print
Dec 10, 2023
df2bf00
change to supported task types
Dec 12, 2023
f80ece1
update proto
Dec 19, 2023
9802aca
agent metadata registry
Dec 19, 2023
ffe91a1
Merge branch 'master' of https://github.com/Future-Outlier/flytekit i…
Dec 19, 2023
95bd27e
metadata log
Dec 19, 2023
3ee34d9
Update plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py
Future-Outlier Dec 19, 2023
5f5a62c
Update flytekit/extend/backend/base_agent.py
Future-Outlier Dec 19, 2023
0ed7f2d
Update flytekit/extend/backend/base_agent.py
Future-Outlier Dec 19, 2023
26ad2ba
Update plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py
Future-Outlier Dec 19, 2023
bf7cf48
Update flytekit/extend/backend/base_agent.py
Future-Outlier Dec 19, 2023
cd50544
Update flytekit/extend/backend/base_agent.py
Future-Outlier Dec 19, 2023
d760b23
lint
Dec 19, 2023
f64308c
change to ListAgents method
Dec 20, 2023
5b5067a
Merge branch 'master' of https://github.com/Future-Outlier/flytekit i…
Jan 9, 2024
4f52e0d
install latest pre release flyteidl
Jan 9, 2024
c72c926
Merge branch 'master' of https://github.com/Future-Outlier/flytekit i…
Jan 11, 2024
be6d0f6
Merge branch 'master' of https://github.com/flyteorg/flytekit into ag…
Jan 13, 2024
248a9ce
refactor
Jan 13, 2024
5d8929e
update dev requirements
Jan 13, 2024
19f0be8
latest idl in Dockerfile.dev
Jan 14, 2024
99bb5a8
nit
Jan 14, 2024
150990f
lint
Jan 16, 2024
7646d41
Merge branch 'master' into agent-metadata-proto-service
Future-Outlier Jan 18, 2024
77cb562
class attribute
pingsutw Jan 31, 2024
c5ae30f
nit
pingsutw Jan 31, 2024
e7afa5d
fix tests
Future-Outlier Jan 31, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions flytekit/clis/sdk_in_container/serve.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from concurrent import futures

import rich_click as click
from flyteidl.service.agent_pb2_grpc import add_AsyncAgentServiceServicer_to_server
from flyteidl.service.agent_pb2_grpc import (
add_AgentMetadataServiceServicer_to_server,
add_AsyncAgentServiceServicer_to_server,
)
from grpc import aio


Expand Down Expand Up @@ -49,7 +52,7 @@ def agent(_: click.Context, port, worker, timeout):

async def _start_grpc_server(port: int, worker: int, timeout: int):
click.secho("Starting up the server to expose the prometheus metrics...", fg="blue")
from flytekit.extend.backend.agent_service import AsyncAgentService
from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService

try:
from prometheus_client import start_http_server
Expand All @@ -61,6 +64,7 @@ async def _start_grpc_server(port: int, worker: int, timeout: int):
server = aio.server(futures.ThreadPoolExecutor(max_workers=worker))

add_AsyncAgentServiceServicer_to_server(AsyncAgentService(), server)
add_AgentMetadataServiceServicer_to_server(AgentMetadataService(), server)

server.add_insecure_port(f"[::]:{port}")
await server.start()
Expand Down
31 changes: 25 additions & 6 deletions flytekit/extend/backend/agent_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@
CreateTaskResponse,
DeleteTaskRequest,
DeleteTaskResponse,
GetAgentRequest,
GetAgentResponse,
GetTaskRequest,
GetTaskResponse,
ListAgentsRequest,
ListAgentsResponse,
)
from flyteidl.service.agent_pb2_grpc import AsyncAgentServiceServicer
from flyteidl.service.agent_pb2_grpc import AgentMetadataServiceServicer, AsyncAgentServiceServicer
from prometheus_client import Counter, Summary

from flytekit import logger
Expand All @@ -26,18 +30,20 @@

# Follow the naming convention. https://prometheus.io/docs/practices/naming/
request_success_count = Counter(
f"{metric_prefix}requests_success_total", "Total number of successful requests", ["task_type", "operation"]
f"{metric_prefix}requests_success_total",
"Total number of successful requests",
["task_type", "operation"],
)
request_failure_count = Counter(
f"{metric_prefix}requests_failure_total",
"Total number of failed requests",
["task_type", "operation", "error_code"],
)

request_latency = Summary(
f"{metric_prefix}request_latency_seconds", "Time spent processing agent request", ["task_type", "operation"]
f"{metric_prefix}request_latency_seconds",
"Time spent processing agent request",
["task_type", "operation"],
)

input_literal_size = Summary(f"{metric_prefix}input_literal_bytes", "Size of input literal", ["task_type"])


Expand Down Expand Up @@ -96,8 +102,12 @@
logger.info(f"{tmp.type} agent start creating the job")
if agent.asynchronous:
return await agent.async_create(
context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp
context=context,
inputs=inputs,
output_prefix=request.output_prefix,
task_template=tmp,
)

return await asyncio.get_running_loop().run_in_executor(
None,
agent.create,
Expand All @@ -122,3 +132,12 @@
if agent.asynchronous:
return await agent.async_delete(context=context, resource_meta=request.resource_meta)
return await asyncio.get_running_loop().run_in_executor(None, agent.delete, context, request.resource_meta)


class AgentMetadataService(AgentMetadataServiceServicer):
async def GetAgent(self, request: GetAgentRequest, context: grpc.ServicerContext) -> GetAgentResponse:
return GetAgentResponse(agent=AgentRegistry._METADATA[request.name])

Check warning on line 139 in flytekit/extend/backend/agent_service.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/agent_service.py#L139

Added line #L139 was not covered by tests

async def ListAgents(self, request: ListAgentsRequest, context: grpc.ServicerContext) -> ListAgentsResponse:
agents = [agent for agent in AgentRegistry._METADATA.values()]
return ListAgentsResponse(agents=agents)

Check warning on line 143 in flytekit/extend/backend/agent_service.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/agent_service.py#L143

Added line #L143 was not covered by tests
27 changes: 23 additions & 4 deletions flytekit/extend/backend/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
RETRYABLE_FAILURE,
RUNNING,
SUCCEEDED,
Agent,
CreateTaskResponse,
DeleteTaskResponse,
GetTaskResponse,
Expand Down Expand Up @@ -45,7 +46,9 @@
will look up the agent based on the task type. Every task type can only have one agent.
"""

def __init__(self, task_type: str, asynchronous=True):
name = "Base Agent"

def __init__(self, task_type: str, asynchronous: bool = True):
self._task_type = task_type
self._asynchronous = asynchronous

Expand Down Expand Up @@ -113,25 +116,41 @@

class AgentRegistry(object):
"""
This is the registry for all agents. The agent service will look up the agent
based on the task type.
This is the registry for all agents.
The agent service will look up the agent registry based on the task type.
The agent metadata service will look up the agent metadata based on the agent name.
"""

_REGISTRY: typing.Dict[str, AgentBase] = {}
_METADATA: typing.Dict[str, Agent] = {}

@staticmethod
def register(agent: AgentBase):
if agent.task_type in AgentRegistry._REGISTRY:
raise ValueError(f"Duplicate agent for task type {agent.task_type}")
AgentRegistry._REGISTRY[agent.task_type] = agent
logger.info(f"Registering an agent for task type {agent.task_type}")

if agent.name in AgentRegistry._METADATA:
agent_metadata = AgentRegistry._METADATA[agent.name]
agent_metadata.supported_task_types.append(agent.task_type)

Check warning on line 135 in flytekit/extend/backend/base_agent.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/base_agent.py#L134-L135

Added lines #L134 - L135 were not covered by tests
else:
agent_metadata = Agent(name=agent.name, supported_task_types=[agent.task_type])
AgentRegistry._METADATA[agent.name] = agent_metadata

logger.info(f"Registering an agent for task type: {agent.task_type}, name: {agent.name}")

@staticmethod
def get_agent(task_type: str) -> typing.Optional[AgentBase]:
if task_type not in AgentRegistry._REGISTRY:
raise FlyteAgentNotFound(f"Cannot find agent for task type: {task_type}.")
return AgentRegistry._REGISTRY[task_type]

@staticmethod
def get_agent_metadata(name: str) -> Agent:
if name not in AgentRegistry._METADATA:
raise FlyteAgentNotFound(f"Cannot find agent for name: {name}.")

Check warning on line 151 in flytekit/extend/backend/base_agent.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/base_agent.py#L151

Added line #L151 was not covered by tests
return AgentRegistry._METADATA[name]


def convert_to_flyte_state(state: str) -> State:
"""
Expand Down
2 changes: 2 additions & 0 deletions flytekit/sensor/sensor_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@


class SensorEngine(AgentBase):
name = "Sensor"

def __init__(self):
super().__init__(task_type="sensor", asynchronous=True)

Expand Down
2 changes: 2 additions & 0 deletions plugins/flytekit-airflow/flytekitplugins/airflow/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ class AirflowAgent(AgentBase):
In this case, those operators will be converted to AirflowContainerTask and executed in the pod.
"""

name = "Airflow Agent"

def __init__(self):
super().__init__(task_type="airflow", asynchronous=True)

Expand Down
2 changes: 2 additions & 0 deletions plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class Metadata:


class BigQueryAgent(AgentBase):
name = "Bigquery Agent"

def __init__(self):
super().__init__(task_type="bigquery_query_job_task", asynchronous=False)

Expand Down
4 changes: 3 additions & 1 deletion plugins/flytekit-mmcloud/flytekitplugins/mmcloud/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ class Metadata:


class MMCloudAgent(AgentBase):
name = "MMCloud Agent"

def __init__(self):
super().__init__(task_type="mmcloud_task")
super().__init__(task_type="mmcloud_task", asynchronous=True)
self._response_format = ["--format", "json"]

async def async_login(self):
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-papermill/dev-requirements.in
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
flyteidl>=1.10.0
flyteidl>=1.10.7b0
-e file:../../.#egg=flytekitplugins-pod&subdirectory=plugins/flytekit-k8s-pod
-e file:../../.#egg=flytekitplugins-spark&subdirectory=plugins/flytekit-spark
-e file:../../.#egg=flytekitplugins-awsbatch&subdirectory=plugins/flytekit-aws-batch
6 changes: 4 additions & 2 deletions plugins/flytekit-papermill/dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ docker-image-py==0.1.12
# via flytekit
docstring-parser==0.15
# via flytekit
flyteidl==1.10.6
flyteidl==1.10.7b0
# via
# -r dev-requirements.in
# flytekit
Expand Down Expand Up @@ -235,7 +235,9 @@ packaging==23.2
# docker
# marshmallow
pandas==1.5.3
# via flytekit
# via
# flytekit
# flytekitplugins-spark
portalocker==2.8.2
# via msal-extensions
protobuf==4.24.4
Expand Down
4 changes: 3 additions & 1 deletion plugins/flytekit-spark/flytekitplugins/spark/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ class Metadata:


class DatabricksAgent(AgentBase):
name = "Databricks Agent"

def __init__(self):
super().__init__(task_type="spark")
super().__init__(task_type="spark", asynchronous=True)

async def async_create(
self,
Expand Down
26 changes: 25 additions & 1 deletion tests/flytekit/unit/extend/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
DeleteTaskResponse,
GetTaskRequest,
GetTaskResponse,
ListAgentsRequest,
ListAgentsResponse,
Resource,
)

from flytekit import PythonFunctionTask, task
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings
from flytekit.extend.backend.agent_service import AsyncAgentService
from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService
from flytekit.extend.backend.base_agent import (
AgentBase,
AgentRegistry,
Expand All @@ -49,6 +51,8 @@ class Metadata:


class DummyAgent(AgentBase):
name = "Dummy Agent"

def __init__(self):
super().__init__(task_type="dummy", asynchronous=False)

Expand All @@ -71,6 +75,8 @@ def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteT


class AsyncDummyAgent(AgentBase):
name = "Async Dummy Agent"

def __init__(self):
super().__init__(task_type="async_dummy", asynchronous=True)

Expand All @@ -91,6 +97,8 @@ async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes


class SyncDummyAgent(AgentBase):
name = "Sync Dummy Agent"

def __init__(self):
super().__init__(task_type="sync_dummy", asynchronous=True)

Expand Down Expand Up @@ -161,6 +169,10 @@ def __init__(self, **kwargs):
with pytest.raises(Exception, match="Cannot find agent for task type: non-exist-type."):
t.execute()

agent_metadata = AgentRegistry.get_agent_metadata("Dummy Agent")
assert agent_metadata.name == "Dummy Agent"
assert agent_metadata.supported_task_types == ["dummy"]


@pytest.mark.asyncio
async def test_async_dummy_agent():
Expand All @@ -175,6 +187,10 @@ async def test_async_dummy_agent():
res = await agent.async_delete(ctx, metadata_bytes)
assert res == DeleteTaskResponse()

agent_metadata = AgentRegistry.get_agent_metadata("Async Dummy Agent")
assert agent_metadata.name == "Async Dummy Agent"
assert agent_metadata.supported_task_types == ["async_dummy"]


@pytest.mark.asyncio
async def test_sync_dummy_agent():
Expand All @@ -185,6 +201,10 @@ async def test_sync_dummy_agent():
assert res.resource.state == SUCCEEDED
assert res.resource.outputs == LiteralMap({}).to_flyte_idl()

agent_metadata = AgentRegistry.get_agent_metadata("Sync Dummy Agent")
assert agent_metadata.name == "Sync Dummy Agent"
assert agent_metadata.supported_task_types == ["sync_dummy"]


@pytest.mark.asyncio
async def run_agent_server():
Expand Down Expand Up @@ -223,6 +243,10 @@ async def run_agent_server():
res = await service.GetTask(GetTaskRequest(task_type=fake_agent, resource_meta=metadata_bytes), ctx)
assert res is None

metadata_service = AgentMetadataService()
res = await metadata_service.ListAgent(ListAgentsRequest(), ctx)
assert isinstance(res, ListAgentsResponse)


def test_agent_server():
loop.run_in_executor(None, run_agent_server)
Expand Down
Loading