Skip to content

Commit

Permalink
[CHORE] Address review changes: collect errors from all failed tasks
Browse files: browse the repository at this point in the history
  • Loading branch information
gaurav7261 committed Apr 30, 2024
1 parent 08e6d3c commit 3f6d898
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 26 deletions.
4 changes: 2 additions & 2 deletions airflow/providers/databricks/operators/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,14 +159,14 @@ def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger)
validate_trigger_event(event)
run_state = RunState.from_json(event["run_state"])
run_page_url = event["run_page_url"]
notebook_error = event["notebook_error"]
errors = event["errors"]
log.info("View run status, Spark UI, and logs at %s", run_page_url)

if run_state.is_successful:
log.info("Job run completed successfully.")
return

error_message = f"Job run failed with terminal state: {run_state} and with the error {notebook_error}"
error_message = f"Job run failed with terminal state: {run_state} and with the errors {errors}"

if event["repair_run"]:
log.warning(
Expand Down
20 changes: 9 additions & 11 deletions airflow/providers/databricks/triggers/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ async def run(self):
async with self.hook:
while True:
run_state = await self.hook.a_get_run_state(self.run_id)
notebook_error = None
if not run_state.is_terminal:
self.log.info(
"run-id %s in run state %s. sleeping for %s seconds",
Expand All @@ -95,27 +94,26 @@ async def run(self):
await asyncio.sleep(self.polling_period_seconds)
continue

failed_tasks = []
if run_state.result_state == "FAILED":
run_info = await self.hook.a_get_run(self.run_id)
task_run_id = None
for task in run_info.get("tasks", []):
if task.get("state", {}).get("result_state", "") == "FAILED":
task_run_id = task["run_id"]
if task_run_id is not None:
run_output = await self.hook.a_get_run_output(task_run_id)
if "error" in run_output:
notebook_error = run_output["error"]
else:
notebook_error = run_state.state_message
else:
notebook_error = run_state.state_message
task_key = task["task_key"]
run_output = await self.hook.a_get_run_output(task_run_id)
if "error" in run_output:
error = run_output["error"]
else:
error = run_state.state_message
failed_tasks.append({"task_key": task_key, "run_id": task_run_id, "error": error})
yield TriggerEvent(
{
"run_id": self.run_id,
"run_page_url": self.run_page_url,
"run_state": run_state.to_json(),
"repair_run": self.repair_run,
"notebook_error": notebook_error,
"errors": failed_tasks,
}
)
return
2 changes: 1 addition & 1 deletion airflow/providers/databricks/utils/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def validate_trigger_event(event: dict):
See: :class:`~airflow.providers.databricks.triggers.databricks.DatabricksExecutionTrigger`.
"""
keys_to_check = ["run_id", "run_page_url", "run_state", "notebook_error"]
keys_to_check = ["run_id", "run_page_url", "run_state", "errors"]
for key in keys_to_check:
if key not in event:
raise AirflowException(f"Could not find `{key}` in the event: {event}")
Expand Down
10 changes: 5 additions & 5 deletions tests/providers/databricks/operators/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,7 +1024,7 @@ def test_execute_complete_success(self):
"run_id": RUN_ID,
"run_page_url": RUN_PAGE_URL,
"run_state": RunState("TERMINATED", "SUCCESS", "").to_json(),
"notebook_error": None,
"errors": [],
}

op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
Expand All @@ -1045,7 +1045,7 @@ def test_execute_complete_failure(self, db_mock_class):
"run_page_url": RUN_PAGE_URL,
"run_state": run_state_failed.to_json(),
"repair_run": False,
"notebook_error": None,
"errors": [],
}

op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
Expand Down Expand Up @@ -1596,7 +1596,7 @@ def test_execute_complete_success(self):
"run_page_url": RUN_PAGE_URL,
"run_state": RunState("TERMINATED", "SUCCESS", "").to_json(),
"repair_run": False,
"notebook_error": None,
"errors": [],
}

op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
Expand All @@ -1614,7 +1614,7 @@ def test_execute_complete_failure(self, db_mock_class):
"run_page_url": RUN_PAGE_URL,
"run_state": run_state_failed.to_json(),
"repair_run": False,
"notebook_error": None,
"errors": [],
}

op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
Expand Down Expand Up @@ -1645,7 +1645,7 @@ def test_execute_complete_failure_and_repair_run(
"run_page_url": RUN_PAGE_URL,
"run_state": run_state_failed.to_json(),
"repair_run": True,
"notebook_error": None,
"errors": [],
}

op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
Expand Down
39 changes: 33 additions & 6 deletions tests/providers/databricks/triggers/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@
RETRY_DELAY = 10
RETRY_LIMIT = 3
RUN_ID = 1
TASK_RUN_ID = 11
TASK_RUN_ID1 = 11
TASK_RUN_ID1_KEY = "first_task"
TASK_RUN_ID2 = 22
TASK_RUN_ID2_KEY = "second_task"
TASK_RUN_ID3 = 33
TASK_RUN_ID3_KEY = "third_task"
JOB_ID = 42
RUN_PAGE_URL = "https://XX.cloud.databricks.com/#jobs/1/runs/1"
ERROR_MESSAGE = "error message from databricks API"
Expand Down Expand Up @@ -80,13 +85,32 @@
},
"tasks": [
{
"run_id": TASK_RUN_ID,
"run_id": TASK_RUN_ID1,
"task_key": TASK_RUN_ID1_KEY,
"state": {
"life_cycle_state": "TERMINATED",
"result_state": "FAILED",
"state_message": "Workload failed, see run output for details",
},
}
},
{
"run_id": TASK_RUN_ID2,
"task_key": TASK_RUN_ID2_KEY,
"state": {
"life_cycle_state": "TERMINATED",
"result_state": "SUCCESS",
"state_message": None,
},
},
{
"run_id": TASK_RUN_ID3,
"task_key": TASK_RUN_ID3_KEY,
"state": {
"life_cycle_state": "TERMINATED",
"result_state": "FAILED",
"state_message": "Workload failed, see run output for details",
},
},
],
}

Expand Down Expand Up @@ -150,7 +174,7 @@ async def test_run_return_success(
).to_json(),
"run_page_url": RUN_PAGE_URL,
"repair_run": False,
"notebook_error": None,
"errors": [],
}
)

Expand Down Expand Up @@ -181,7 +205,10 @@ async def test_run_return_failure(
).to_json(),
"run_page_url": RUN_PAGE_URL,
"repair_run": False,
"notebook_error": ERROR_MESSAGE,
"errors": [
{"task_key": TASK_RUN_ID1_KEY, "run_id": TASK_RUN_ID1, "error": ERROR_MESSAGE},
{"task_key": TASK_RUN_ID3_KEY, "run_id": TASK_RUN_ID3, "error": ERROR_MESSAGE},
],
}
)

Expand Down Expand Up @@ -218,7 +245,7 @@ async def test_sleep_between_retries(
).to_json(),
"run_page_url": RUN_PAGE_URL,
"repair_run": False,
"notebook_error": None,
"errors": [],
}
)
mock_sleep.assert_called_once()
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/databricks/utils/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_validate_trigger_event_success(self):
"run_id": RUN_ID,
"run_page_url": RUN_PAGE_URL,
"run_state": RunState("TERMINATED", "SUCCESS", "").to_json(),
"notebook_error": None,
"errors": [],
}
assert validate_trigger_event(event) is None

Expand Down

0 comments on commit 3f6d898

Please sign in to comment.