Skip to content

Commit

Permalink
AIP-72: Add support for fetching XComs in Supervisor (apache#44408)
Browse files Browse the repository at this point in the history
  • Loading branch information
kaxil authored and Lefteris Gilmaz committed Jan 5, 2025
1 parent d36714b commit 0ef1469
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 4 deletions.
19 changes: 19 additions & 0 deletions task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
TITerminalStatePayload,
ValidationError as RemoteValidationError,
VariableResponse,
XComResponse,
)
from airflow.utils.net import get_hostname
from airflow.utils.platform import getuser
Expand Down Expand Up @@ -148,6 +149,18 @@ def get(self, key: str) -> VariableResponse:
return VariableResponse.model_validate_json(resp.read())


class XComOperations:
__slots__ = ("client",)

def __init__(self, client: Client):
self.client = client

def get(self, dag_id: str, run_id: str, task_id: str, key: str, map_index: int = -1) -> XComResponse:
"""Get a XCom value from the API server."""
resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params={"map_index": map_index})
return XComResponse.model_validate_json(resp.read())


class BearerAuth(httpx.Auth):
def __init__(self, token: str):
self.token: str = token
Expand Down Expand Up @@ -208,6 +221,12 @@ def variables(self) -> VariableOperations:
"""Operations related to Variables."""
return VariableOperations(self)

@lru_cache() # type: ignore[misc]
@property
def xcoms(self) -> XComOperations:
"""Operations related to XComs."""
return XComOperations(self)


# This is only used for parsing. ServerResponseError is raised instead
class _ErrorBody(BaseModel):
Expand Down
4 changes: 4 additions & 0 deletions task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ class TaskState(BaseModel):

class GetXCom(BaseModel):
key: str
dag_id: str
run_id: str
task_id: str
map_index: int = -1
type: Literal["GetXCom"] = "GetXCom"


Expand Down
4 changes: 4 additions & 0 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from airflow.sdk.execution_time.comms import (
GetConnection,
GetVariable,
GetXCom,
StartupDetails,
ToSupervisor,
)
Expand Down Expand Up @@ -514,6 +515,9 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N
elif isinstance(msg, GetVariable):
var = self.client.variables.get(msg.key)
resp = var.model_dump_json(exclude_unset=True).encode()
elif isinstance(msg, GetXCom):
xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
resp = xcom.model_dump_json(exclude_unset=True).encode()
else:
log.error("Unhandled request", msg=msg)
continue
Expand Down
23 changes: 19 additions & 4 deletions task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,14 @@
from airflow.sdk.api import client as sdk_client
from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity
from airflow.sdk.execution_time.comms import ConnectionResult, GetConnection, GetVariable, VariableResult
from airflow.sdk.execution_time.comms import (
ConnectionResult,
GetConnection,
GetVariable,
GetXCom,
VariableResult,
XComResult,
)
from airflow.sdk.execution_time.supervisor import WatchedSubprocess, supervise
from airflow.utils import timezone as tz

Expand Down Expand Up @@ -278,18 +285,26 @@ def watched_subprocess(self, mocker):
GetConnection(conn_id="test_conn"),
b'{"conn_id":"test_conn","conn_type":"mysql"}',
"connections.get",
"test_conn",
("test_conn",),
ConnectionResult(conn_id="test_conn", conn_type="mysql"),
id="get_connection",
),
pytest.param(
GetVariable(key="test_key"),
b'{"key":"test_key","value":"test_value"}',
"variables.get",
"test_key",
("test_key",),
VariableResult(key="test_key", value="test_value"),
id="get_variable",
),
pytest.param(
GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"),
b'{"key":"test_key","value":"test_value"}',
"xcoms.get",
("test_dag", "test_run", "test_task", "test_key", -1),
XComResult(key="test_key", value="test_value"),
id="get_xcom",
),
],
)
def test_handle_requests(
Expand Down Expand Up @@ -325,7 +340,7 @@ def test_handle_requests(
generator.send(msg)

# Verify the correct client method was called
mock_client_method.assert_called_once_with(method_arg)
mock_client_method.assert_called_once_with(*method_arg)

# Verify the response was added to the buffer
assert watched_subprocess.stdin.getvalue() == expected_buffer + b"\n"

0 comments on commit 0ef1469

Please sign in to comment.