[FEAT] raise exception with main notebook error in DatabricksRunNowDeferrableOperator #39110

Merged: 12 commits, May 1, 2024
11 changes: 11 additions & 0 deletions airflow/providers/databricks/hooks/databricks.py
@@ -491,6 +491,17 @@ def get_run_output(self, run_id: int) -> dict:
         run_output = self._do_api_call(OUTPUT_RUNS_JOB_ENDPOINT, json)
         return run_output
 
+    async def a_get_run_output(self, run_id: int) -> dict:
+        """
+        Async version of `get_run_output()`.
+
+        :param run_id: id of the run
+        :return: output of the run
+        """
+        json = {"run_id": run_id}
+        run_output = await self._a_do_api_call(OUTPUT_RUNS_JOB_ENDPOINT, json)
+        return run_output
+
     def cancel_run(self, run_id: int) -> None:
         """
         Cancel the run.
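For context, a minimal usage sketch of the new async method (not part of this PR; the connection id and run id are hypothetical):

from __future__ import annotations

import asyncio

from airflow.providers.databricks.hooks.databricks import DatabricksHook


async def fetch_notebook_error(run_id: int) -> str | None:
    # Hypothetical connection id; any configured Databricks connection works.
    hook = DatabricksHook(databricks_conn_id="databricks_default")
    async with hook:
        # Hits the same runs/get-output endpoint as the sync get_run_output(),
        # but awaited, so it can run inside a trigger's event loop.
        run_output = await hook.a_get_run_output(run_id)
    # "error" is only present in the payload when the task failed.
    return run_output.get("error")


# asyncio.run(fetch_notebook_error(12345))  # hypothetical run id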
11 changes: 6 additions & 5 deletions airflow/providers/databricks/operators/databricks.py
@@ -70,10 +70,9 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None:
 
                 if run_state.result_state == "FAILED":
                     task_run_id = None
-                    if "tasks" in run_info:
-                        for task in run_info["tasks"]:
-                            if task.get("state", {}).get("result_state", "") == "FAILED":
-                                task_run_id = task["run_id"]
+                    for task in run_info.get("tasks", []):
+                        if task.get("state", {}).get("result_state", "") == "FAILED":
+                            task_run_id = task["run_id"]
                     if task_run_id is not None:
                         run_output = hook.get_run_output(task_run_id)
                         if "error" in run_output:
@@ -160,13 +159,15 @@ def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger)
     validate_trigger_event(event)
     run_state = RunState.from_json(event["run_state"])
     run_page_url = event["run_page_url"]
+    errors = event["errors"]
     log.info("View run status, Spark UI, and logs at %s", run_page_url)
 
     if run_state.is_successful:
         log.info("Job run completed successfully.")
         return
 
-    error_message = f"Job run failed with terminal state: {run_state}"
+    error_message = f"Job run failed with terminal state: {run_state} and with the errors {errors}"
 
     if event["repair_run"]:
         log.warning(
             "%s but since repair run is set, repairing the run with all failed tasks",
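To illustrate the handler change, a hedged sketch of an event payload the trigger now emits and the message that would be raised (all values hypothetical):

from airflow.providers.databricks.hooks.databricks import RunState

event = {
    "run_id": 1,
    "run_page_url": "https://xx.cloud.databricks.com/#jobs/1/runs/1",
    "run_state": RunState("TERMINATED", "FAILED", "").to_json(),
    "repair_run": False,
    "errors": [{"task_key": "first_task", "run_id": 11, "error": "division by zero"}],
}

run_state = RunState.from_json(event["run_state"])
errors = event["errors"]
# Mirrors the f-string above: per-task errors now surface in the exception text.
error_message = f"Job run failed with terminal state: {run_state} and with the errors {errors}"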
45 changes: 30 additions & 15 deletions airflow/providers/databricks/triggers/databricks.py
@@ -84,21 +84,36 @@ async def run(self):
         async with self.hook:
             while True:
                 run_state = await self.hook.a_get_run_state(self.run_id)
-                if run_state.is_terminal:
-                    yield TriggerEvent(
-                        {
-                            "run_id": self.run_id,
-                            "run_page_url": self.run_page_url,
-                            "run_state": run_state.to_json(),
-                            "repair_run": self.repair_run,
-                        }
+                if not run_state.is_terminal:
+                    self.log.info(
+                        "run-id %s in run state %s. sleeping for %s seconds",
+                        self.run_id,
+                        run_state,
+                        self.polling_period_seconds,
                     )
-                    return
+                    await asyncio.sleep(self.polling_period_seconds)
+                    continue
 
-                self.log.info(
-                    "run-id %s in run state %s. sleeping for %s seconds",
-                    self.run_id,
-                    run_state,
-                    self.polling_period_seconds,
+                failed_tasks = []
+                if run_state.result_state == "FAILED":
+                    run_info = await self.hook.a_get_run(self.run_id)
+                    for task in run_info.get("tasks", []):
+                        if task.get("state", {}).get("result_state", "") == "FAILED":
+                            task_run_id = task["run_id"]
+                            task_key = task["task_key"]
+                            run_output = await self.hook.a_get_run_output(task_run_id)
+                            if "error" in run_output:
+                                error = run_output["error"]
+                            else:
+                                error = run_state.state_message
+                            failed_tasks.append({"task_key": task_key, "run_id": task_run_id, "error": error})
+                yield TriggerEvent(
+                    {
+                        "run_id": self.run_id,
+                        "run_page_url": self.run_page_url,
+                        "run_state": run_state.to_json(),
+                        "repair_run": self.repair_run,
+                        "errors": failed_tasks,
+                    }
                 )
-                await asyncio.sleep(self.polling_period_seconds)
+                return
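A synchronous paraphrase of the new error-collection step, showing how `failed_tasks` is built from a hypothetical run_info and get-output payload:

run_info = {
    "tasks": [
        {"run_id": 11, "task_key": "first_task", "state": {"result_state": "FAILED"}},
        {"run_id": 22, "task_key": "second_task", "state": {"result_state": "SUCCESS"}},
    ]
}
run_outputs = {11: {"error": "division by zero"}}  # stand-in for await hook.a_get_run_output(...)

failed_tasks = []
for task in run_info.get("tasks", []):
    if task.get("state", {}).get("result_state", "") == "FAILED":
        run_output = run_outputs.get(task["run_id"], {})
        # Fall back to the run-level state message when get-output carries no "error".
        error = run_output.get("error", "run-level state message")
        failed_tasks.append({"task_key": task["task_key"], "run_id": task["run_id"], "error": error})

assert failed_tasks == [{"task_key": "first_task", "run_id": 11, "error": "division by zero"}]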
2 changes: 1 addition & 1 deletion airflow/providers/databricks/utils/databricks.py
@@ -55,7 +55,7 @@ def validate_trigger_event(event: dict):
 
     See: :class:`~airflow.providers.databricks.triggers.databricks.DatabricksExecutionTrigger`.
     """
-    keys_to_check = ["run_id", "run_page_url", "run_state"]
+    keys_to_check = ["run_id", "run_page_url", "run_state", "errors"]
     for key in keys_to_check:
         if key not in event:
             raise AirflowException(f"Could not find `{key}` in the event: {event}")
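A quick sketch (hypothetical values) of the stricter validation: an otherwise well-formed event that lacks the new `errors` key is now rejected.

from airflow.exceptions import AirflowException
from airflow.providers.databricks.hooks.databricks import RunState
from airflow.providers.databricks.utils.databricks import validate_trigger_event

event = {
    "run_id": 1,
    "run_page_url": "https://xx.cloud.databricks.com/#jobs/1/runs/1",
    "run_state": RunState("TERMINATED", "SUCCESS", "").to_json(),
    # "errors" deliberately omitted
}

try:
    validate_trigger_event(event)
except AirflowException as exc:
    print(exc)  # Could not find `errors` in the event: {...}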
17 changes: 17 additions & 0 deletions tests/providers/databricks/hooks/test_databricks.py
@@ -1603,6 +1603,23 @@ async def test_get_cluster_state(self, mock_get):
timeout=self.hook.timeout_seconds,
)

@pytest.mark.asyncio
@mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get")
async def test_get_run_output(self, mock_get):
mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_RUN_OUTPUT_RESPONSE)
async with self.hook:
run_output = await self.hook.a_get_run_output(RUN_ID)
run_output_error = run_output.get("error")

assert run_output_error == ERROR_MESSAGE
mock_get.assert_called_once_with(
get_run_output_endpoint(HOST),
json={"run_id": RUN_ID},
auth=aiohttp.BasicAuth(LOGIN, PASSWORD),
headers=self.hook.user_agent_header,
timeout=self.hook.timeout_seconds,
)


@pytest.mark.db_test
class TestDatabricksHookAsyncAadToken:
5 changes: 5 additions & 0 deletions tests/providers/databricks/operators/test_databricks.py
@@ -1024,6 +1024,7 @@ def test_execute_complete_success(self):
"run_id": RUN_ID,
"run_page_url": RUN_PAGE_URL,
"run_state": RunState("TERMINATED", "SUCCESS", "").to_json(),
"errors": [],
}

op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
@@ -1044,6 +1045,7 @@ def test_execute_complete_failure(self, db_mock_class):
"run_page_url": RUN_PAGE_URL,
"run_state": run_state_failed.to_json(),
"repair_run": False,
"errors": [],
}

op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
@@ -1594,6 +1596,7 @@ def test_execute_complete_success(self):
"run_page_url": RUN_PAGE_URL,
"run_state": RunState("TERMINATED", "SUCCESS", "").to_json(),
"repair_run": False,
"errors": [],
}

op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
@@ -1611,6 +1614,7 @@ def test_execute_complete_failure(self, db_mock_class):
"run_page_url": RUN_PAGE_URL,
"run_state": run_state_failed.to_json(),
"repair_run": False,
"errors": [],
}

op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
@@ -1641,6 +1645,7 @@ def test_execute_complete_failure_and_repair_run(
"run_page_url": RUN_PAGE_URL,
"run_state": run_state_failed.to_json(),
"repair_run": True,
"errors": [],
}

op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
99 changes: 97 additions & 2 deletions tests/providers/databricks/triggers/test_databricks.py
@@ -38,13 +38,22 @@
RETRY_DELAY = 10
RETRY_LIMIT = 3
RUN_ID = 1
TASK_RUN_ID1 = 11
TASK_RUN_ID1_KEY = "first_task"
TASK_RUN_ID2 = 22
TASK_RUN_ID2_KEY = "second_task"
TASK_RUN_ID3 = 33
TASK_RUN_ID3_KEY = "third_task"
JOB_ID = 42
RUN_PAGE_URL = "https://XX.cloud.databricks.com/#jobs/1/runs/1"
ERROR_MESSAGE = "error message from databricks API"
GET_RUN_OUTPUT_RESPONSE = {"metadata": {}, "error": ERROR_MESSAGE, "notebook_output": {}}

RUN_LIFE_CYCLE_STATES = ["PENDING", "RUNNING", "TERMINATING", "TERMINATED", "SKIPPED", "INTERNAL_ERROR"]

LIFE_CYCLE_STATE_PENDING = "PENDING"
LIFE_CYCLE_STATE_TERMINATED = "TERMINATED"
LIFE_CYCLE_STATE_INTERNAL_ERROR = "INTERNAL_ERROR"

STATE_MESSAGE = "Waiting for cluster"

@@ -66,6 +75,44 @@
"result_state": "SUCCESS",
},
}
GET_RUN_RESPONSE_TERMINATED_WITH_FAILED = {
"job_id": JOB_ID,
"run_page_url": RUN_PAGE_URL,
"state": {
"life_cycle_state": LIFE_CYCLE_STATE_INTERNAL_ERROR,
"state_message": None,
"result_state": "FAILED",
},
"tasks": [
{
"run_id": TASK_RUN_ID1,
"task_key": TASK_RUN_ID1_KEY,
"state": {
"life_cycle_state": "TERMINATED",
"result_state": "FAILED",
"state_message": "Workload failed, see run output for details",
},
},
{
"run_id": TASK_RUN_ID2,
"task_key": TASK_RUN_ID2_KEY,
"state": {
"life_cycle_state": "TERMINATED",
"result_state": "SUCCESS",
"state_message": None,
},
},
{
"run_id": TASK_RUN_ID3,
"task_key": TASK_RUN_ID3_KEY,
"state": {
"life_cycle_state": "TERMINATED",
"result_state": "FAILED",
"state_message": "Workload failed, see run output for details",
},
},
],
}


class TestDatabricksExecutionTrigger:
@@ -101,15 +148,21 @@ def test_serialize(self):
)

@pytest.mark.asyncio
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_output")
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run")
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_page_url")
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state")
-    async def test_run_return_success(self, mock_get_run_state, mock_get_run_page_url):
+    async def test_run_return_success(
+        self, mock_get_run_state, mock_get_run_page_url, mock_get_run, mock_get_run_output
+    ):
mock_get_run_page_url.return_value = RUN_PAGE_URL
mock_get_run_state.return_value = RunState(
life_cycle_state=LIFE_CYCLE_STATE_TERMINATED,
state_message="",
result_state="SUCCESS",
)
mock_get_run.return_value = GET_RUN_RESPONSE_TERMINATED
mock_get_run_output.return_value = GET_RUN_OUTPUT_RESPONSE

trigger_event = self.trigger.run()
async for event in trigger_event:
@@ -121,13 +174,52 @@ async def test_run_return_success(self, mock_get_run_state, mock_get_run_page_url):
).to_json(),
"run_page_url": RUN_PAGE_URL,
"repair_run": False,
"errors": [],
}
)

@pytest.mark.asyncio
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_output")
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run")
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_page_url")
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state")
async def test_run_return_failure(
self, mock_get_run_state, mock_get_run_page_url, mock_get_run, mock_get_run_output
):
mock_get_run_page_url.return_value = RUN_PAGE_URL
mock_get_run_state.return_value = RunState(
life_cycle_state=LIFE_CYCLE_STATE_TERMINATED,
state_message="",
result_state="FAILED",
)
mock_get_run_output.return_value = GET_RUN_OUTPUT_RESPONSE
mock_get_run.return_value = GET_RUN_RESPONSE_TERMINATED_WITH_FAILED

trigger_event = self.trigger.run()
async for event in trigger_event:
assert event == TriggerEvent(
{
"run_id": RUN_ID,
"run_state": RunState(
life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message="", result_state="FAILED"
).to_json(),
"run_page_url": RUN_PAGE_URL,
"repair_run": False,
"errors": [
{"task_key": TASK_RUN_ID1_KEY, "run_id": TASK_RUN_ID1, "error": ERROR_MESSAGE},
{"task_key": TASK_RUN_ID3_KEY, "run_id": TASK_RUN_ID3, "error": ERROR_MESSAGE},
],
}
)

@pytest.mark.asyncio
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_output")
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run")
@mock.patch("airflow.providers.databricks.triggers.databricks.asyncio.sleep")
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state")
-    async def test_sleep_between_retries(self, mock_get_run_state, mock_sleep):
+    async def test_sleep_between_retries(
+        self, mock_get_run_state, mock_sleep, mock_get_run, mock_get_run_output
+    ):
mock_get_run_state.side_effect = [
RunState(
life_cycle_state=LIFE_CYCLE_STATE_PENDING,
@@ -140,6 +232,8 @@ async def test_sleep_between_retries(self, mock_get_run_state, mock_sleep):
result_state="SUCCESS",
),
]
mock_get_run.return_value = GET_RUN_RESPONSE_TERMINATED
mock_get_run_output.return_value = GET_RUN_OUTPUT_RESPONSE

trigger_event = self.trigger.run()
async for event in trigger_event:
@@ -151,6 +245,7 @@ async def test_sleep_between_retries(self, mock_get_run_state, mock_sleep):
).to_json(),
"run_page_url": RUN_PAGE_URL,
"repair_run": False,
"errors": [],
}
)
mock_sleep.assert_called_once()
1 change: 1 addition & 0 deletions tests/providers/databricks/utils/test_databricks.py
@@ -53,6 +53,7 @@ def test_validate_trigger_event_success(self):
"run_id": RUN_ID,
"run_page_url": RUN_PAGE_URL,
"run_state": RunState("TERMINATED", "SUCCESS", "").to_json(),
"errors": [],
}
assert validate_trigger_event(event) is None
