Skip to content

Commit

Permalink
Support Databricks Agent API 2.1 and Console URL (#1935)
Browse files Browse the repository at this point in the history
Signed-off-by: Future Outlier <[email protected]>
Co-authored-by: Future Outlier <[email protected]>
  • Loading branch information
Future-Outlier and Future Outlier authored Dec 12, 2023
1 parent 492997e commit 935778f
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 13 deletions.
20 changes: 15 additions & 5 deletions flytekit/models/core/execution.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import datetime
import typing

from flyteidl.core import execution_pb2 as _execution_pb2

from flytekit.models import common as _common
Expand Down Expand Up @@ -188,11 +191,17 @@ class MessageFormat(object):
CSV = _execution_pb2.TaskLog.CSV
JSON = _execution_pb2.TaskLog.JSON

def __init__(self, uri, name, message_format, ttl):
def __init__(
self,
uri: str,
name: str,
message_format: typing.Optional[MessageFormat] = None,
ttl: typing.Optional[datetime.timedelta] = None,
):
"""
:param Text uri:
:param Text name:
:param int message_format: Enum value from TaskLog.MessageFormat
:param MessageFormat message_format: Enum value from TaskLog.MessageFormat
:param datetime.timedelta ttl: The time the log will persist for. 0 represents unknown or ephemeral in nature.
"""
self._uri = uri
Expand All @@ -218,7 +227,7 @@ def name(self):
def message_format(self):
"""
Enum value from TaskLog.MessageFormat
:rtype: int
:rtype: MessageFormat
"""
return self._message_format

Expand All @@ -234,7 +243,8 @@ def to_flyte_idl(self):
:rtype: flyteidl.core.execution_pb2.TaskLog
"""
p = _execution_pb2.TaskLog(uri=self.uri, name=self.name, message_format=self.message_format)
p.ttl.FromTimedelta(self.ttl)
if self.ttl is not None:
p.ttl.FromTimedelta(self.ttl)
return p

@classmethod
Expand All @@ -247,5 +257,5 @@ def from_flyte_idl(cls, p):
uri=p.uri,
name=p.name,
message_format=p.message_format,
ttl=p.ttl.ToTimedelta(),
ttl=p.ttl.ToTimedelta() if p.ttl else None,
)
11 changes: 8 additions & 3 deletions plugins/flytekit-spark/flytekitplugins/spark/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
from flyteidl.admin.agent_pb2 import PENDING, CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource

from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_state, get_agent_secret
from flytekit.models.core.execution import TaskLog
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate

DATABRICKS_API_ENDPOINT = "/api/2.0/jobs"
DATABRICKS_API_ENDPOINT = "/api/2.1/jobs"


@dataclass
Expand Down Expand Up @@ -93,7 +94,11 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -
if state.get("state_message"):
message = state["state_message"]

return GetTaskResponse(resource=Resource(state=cur_state, message=message))
job_id = response.get("job_id")
databricks_console_url = f"https://{databricks_instance}/#job/{job_id}/run/{metadata.run_id}"
log_links = [TaskLog(uri=databricks_console_url, name="Databricks Console").to_flyte_idl()]

return GetTaskResponse(resource=Resource(state=cur_state, message=message), log_links=log_links)

async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse:
metadata = pickle.loads(resource_meta)
Expand All @@ -103,7 +108,7 @@ async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes

async with aiohttp.ClientSession() as session:
async with session.post(databricks_url, headers=get_header(), data=data) as resp:
if resp.status != 200:
if resp.status != http.HTTPStatus.OK:
raise Exception(f"Failed to cancel databricks job {metadata.run_id} with error: {resp.reason}")
await resp.json()

Expand Down
4 changes: 3 additions & 1 deletion plugins/flytekit-spark/flytekitplugins/spark/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ class Databricks(Spark):
natively onto databricks platform as a distributed execution of spark
Args:
databricks_conf: Databricks job configuration. Config structure can be found here. https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure
databricks_conf: Databricks job configuration compliant with API version 2.1, supporting 2.0 use cases.
For the configuration structure, visit here.https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure
For updates in API 2.1, refer to: https://docs.databricks.com/en/workflows/jobs/jobs-api-updates.html
databricks_token: Databricks access token. https://docs.databricks.com/dev-tools/api/latest/authentication.html.
databricks_instance: Domain name of your deployment. Use the form <account>.cloud.databricks.com.
"""
Expand Down
11 changes: 7 additions & 4 deletions plugins/flytekit-spark/tests/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import http
import pickle
from datetime import timedelta
from unittest import mock
Expand Down Expand Up @@ -113,23 +114,25 @@ async def test_databricks_agent():
)

mock_create_response = {"run_id": "123"}
mock_get_response = {"run_id": "123", "state": {"result_state": "SUCCESS", "state_message": "OK"}}
mock_get_response = {"job_id": "1", "run_id": "123", "state": {"result_state": "SUCCESS", "state_message": "OK"}}
mock_delete_response = {}
create_url = f"https://test-account.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/submit"
get_url = f"https://test-account.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/get?run_id=123"
delete_url = f"https://test-account.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/cancel"
with aioresponses() as mocked:
mocked.post(create_url, status=200, payload=mock_create_response)
mocked.post(create_url, status=http.HTTPStatus.OK, payload=mock_create_response)
res = await agent.async_create(ctx, "/tmp", dummy_template, None)
assert res.resource_meta == metadata_bytes

mocked.get(get_url, status=200, payload=mock_get_response)
mocked.get(get_url, status=http.HTTPStatus.OK, payload=mock_get_response)
res = await agent.async_get(ctx, metadata_bytes)
assert res.resource.state == SUCCEEDED
assert res.resource.outputs == literals.LiteralMap({}).to_flyte_idl()
assert res.resource.message == "OK"
assert res.log_links[0].name == "Databricks Console"
assert res.log_links[0].uri == "https://test-account.cloud.databricks.com/#job/1/run/123"

mocked.post(delete_url, status=200, payload=mock_delete_response)
mocked.post(delete_url, status=http.HTTPStatus.OK, payload=mock_delete_response)
await agent.async_delete(ctx, metadata_bytes)

assert get_header() == {"Authorization": f"Bearer {mocked_token}", "content-type": "application/json"}
Expand Down

0 comments on commit 935778f

Please sign in to comment.