From b8bd6d9783483658da4b68e000b62a7909e03171 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Wed, 13 Dec 2023 04:05:09 +0800 Subject: [PATCH] Support Databricks Agent API 2.1 and Console URL (#1935) Signed-off-by: Future Outlier Co-authored-by: Future Outlier Signed-off-by: Rafael Raposo --- flytekit/models/core/execution.py | 20 ++++++++++++++----- .../flytekitplugins/spark/agent.py | 11 +++++++--- .../flytekitplugins/spark/task.py | 4 +++- plugins/flytekit-spark/tests/test_agent.py | 11 ++++++---- 4 files changed, 33 insertions(+), 13 deletions(-) diff --git a/flytekit/models/core/execution.py b/flytekit/models/core/execution.py index a7b7c6a5f5..c8ef48fcc5 100644 --- a/flytekit/models/core/execution.py +++ b/flytekit/models/core/execution.py @@ -1,3 +1,6 @@ +import datetime +import typing + from flyteidl.core import execution_pb2 as _execution_pb2 from flytekit.models import common as _common @@ -188,11 +191,17 @@ class MessageFormat(object): CSV = _execution_pb2.TaskLog.CSV JSON = _execution_pb2.TaskLog.JSON - def __init__(self, uri, name, message_format, ttl): + def __init__( + self, + uri: str, + name: str, + message_format: typing.Optional[MessageFormat] = None, + ttl: typing.Optional[datetime.timedelta] = None, + ): """ :param Text uri: :param Text name: - :param int message_format: Enum value from TaskLog.MessageFormat + :param MessageFormat message_format: Enum value from TaskLog.MessageFormat :param datetime.timedelta ttl: The time the log will persist for. 0 represents unknown or ephemeral in nature. 
""" self._uri = uri @@ -218,7 +227,7 @@ def name(self): def message_format(self): """ Enum value from TaskLog.MessageFormat - :rtype: int + :rtype: MessageFormat """ return self._message_format @@ -234,7 +243,8 @@ def to_flyte_idl(self): :rtype: flyteidl.core.execution_pb2.TaskLog """ p = _execution_pb2.TaskLog(uri=self.uri, name=self.name, message_format=self.message_format) - p.ttl.FromTimedelta(self.ttl) + if self.ttl is not None: + p.ttl.FromTimedelta(self.ttl) return p @classmethod @@ -247,5 +257,5 @@ def from_flyte_idl(cls, p): uri=p.uri, name=p.name, message_format=p.message_format, - ttl=p.ttl.ToTimedelta(), + ttl=p.ttl.ToTimedelta() if p.ttl else None, ) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py index 053f0eaf20..fcbe276e7d 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/agent.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/agent.py @@ -10,10 +10,11 @@ from flyteidl.admin.agent_pb2 import PENDING, CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_state, get_agent_secret +from flytekit.models.core.execution import TaskLog from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate -DATABRICKS_API_ENDPOINT = "/api/2.0/jobs" +DATABRICKS_API_ENDPOINT = "/api/2.1/jobs" @dataclass @@ -93,7 +94,11 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - if state.get("state_message"): message = state["state_message"] - return GetTaskResponse(resource=Resource(state=cur_state, message=message)) + job_id = response.get("job_id") + databricks_console_url = f"https://{databricks_instance}/#job/{job_id}/run/{metadata.run_id}" + log_links = [TaskLog(uri=databricks_console_url, name="Databricks Console").to_flyte_idl()] + + return GetTaskResponse(resource=Resource(state=cur_state, message=message), 
log_links=log_links) async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: metadata = pickle.loads(resource_meta) @@ -103,7 +108,7 @@ async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes async with aiohttp.ClientSession() as session: async with session.post(databricks_url, headers=get_header(), data=data) as resp: - if resp.status != 200: + if resp.status != http.HTTPStatus.OK: raise Exception(f"Failed to cancel databricks job {metadata.run_id} with error: {resp.reason}") await resp.json() diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 5b3e5a135a..6c692fb726 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -48,7 +48,9 @@ class Databricks(Spark): natively onto databricks platform as a distributed execution of spark Args: - databricks_conf: Databricks job configuration. Config structure can be found here. https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure + databricks_conf: Databricks job configuration compliant with API version 2.1, supporting 2.0 use cases. + For the configuration structure, visit here: https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure + For updates in API 2.1, refer to: https://docs.databricks.com/en/workflows/jobs/jobs-api-updates.html databricks_token: Databricks access token. https://docs.databricks.com/dev-tools/api/latest/authentication.html. databricks_instance: Domain name of your deployment. Use the form .cloud.databricks.com. 
""" diff --git a/plugins/flytekit-spark/tests/test_agent.py b/plugins/flytekit-spark/tests/test_agent.py index 823824a15b..7effe0b9e5 100644 --- a/plugins/flytekit-spark/tests/test_agent.py +++ b/plugins/flytekit-spark/tests/test_agent.py @@ -1,3 +1,4 @@ +import http import pickle from datetime import timedelta from unittest import mock @@ -113,23 +114,25 @@ async def test_databricks_agent(): ) mock_create_response = {"run_id": "123"} - mock_get_response = {"run_id": "123", "state": {"result_state": "SUCCESS", "state_message": "OK"}} + mock_get_response = {"job_id": "1", "run_id": "123", "state": {"result_state": "SUCCESS", "state_message": "OK"}} mock_delete_response = {} create_url = f"https://test-account.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/submit" get_url = f"https://test-account.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/get?run_id=123" delete_url = f"https://test-account.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/cancel" with aioresponses() as mocked: - mocked.post(create_url, status=200, payload=mock_create_response) + mocked.post(create_url, status=http.HTTPStatus.OK, payload=mock_create_response) res = await agent.async_create(ctx, "/tmp", dummy_template, None) assert res.resource_meta == metadata_bytes - mocked.get(get_url, status=200, payload=mock_get_response) + mocked.get(get_url, status=http.HTTPStatus.OK, payload=mock_get_response) res = await agent.async_get(ctx, metadata_bytes) assert res.resource.state == SUCCEEDED assert res.resource.outputs == literals.LiteralMap({}).to_flyte_idl() assert res.resource.message == "OK" + assert res.log_links[0].name == "Databricks Console" + assert res.log_links[0].uri == "https://test-account.cloud.databricks.com/#job/1/run/123" - mocked.post(delete_url, status=200, payload=mock_delete_response) + mocked.post(delete_url, status=http.HTTPStatus.OK, payload=mock_delete_response) await agent.async_delete(ctx, metadata_bytes) assert get_header() == {"Authorization": f"Bearer 
{mocked_token}", "content-type": "application/json"}