Skip to content

Commit

Permalink
[FIX] tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gaurav7261 committed Apr 22, 2024
1 parent 1b08b98 commit f406db4
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 6 deletions.
2 changes: 1 addition & 1 deletion airflow/providers/databricks/hooks/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ async def a_get_run_output(self, run_id: int) -> dict:
:return: output of the run
"""
json = {"run_id": run_id}
run_output = await self._do_api_call(OUTPUT_RUNS_JOB_ENDPOINT, json)
run_output = await self._a_do_api_call(OUTPUT_RUNS_JOB_ENDPOINT, json)
return run_output

def cancel_run(self, run_id: int) -> None:
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/databricks/triggers/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ async def run(self):
"run_page_url": self.run_page_url,
"run_state": run_state.to_json(),
"repair_run": self.repair_run,
"notebook_error": notebook_error
"notebook_error": notebook_error,
}
)
return
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/databricks/hooks/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1614,7 +1614,7 @@ async def test_get_run_output(self, mock_get):
assert run_output_error == ERROR_MESSAGE
mock_get.assert_called_once_with(
get_run_output_endpoint(HOST),
json=None,
json={"run_id": RUN_ID},
auth=aiohttp.BasicAuth(LOGIN, PASSWORD),
headers=self.hook.user_agent_header,
timeout=self.hook.timeout_seconds,
Expand Down
5 changes: 5 additions & 0 deletions tests/providers/databricks/operators/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,6 +1013,7 @@ def test_execute_complete_success(self):
"run_id": RUN_ID,
"run_page_url": RUN_PAGE_URL,
"run_state": RunState("TERMINATED", "SUCCESS", "").to_json(),
"notebook_error": None,
}

op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
Expand All @@ -1033,6 +1034,7 @@ def test_execute_complete_failure(self, db_mock_class):
"run_page_url": RUN_PAGE_URL,
"run_state": run_state_failed.to_json(),
"repair_run": False,
"notebook_error": None,
}

op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
Expand Down Expand Up @@ -1583,6 +1585,7 @@ def test_execute_complete_success(self):
"run_page_url": RUN_PAGE_URL,
"run_state": RunState("TERMINATED", "SUCCESS", "").to_json(),
"repair_run": False,
"notebook_error": None,
}

op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
Expand All @@ -1600,6 +1603,7 @@ def test_execute_complete_failure(self, db_mock_class):
"run_page_url": RUN_PAGE_URL,
"run_state": run_state_failed.to_json(),
"repair_run": False,
"notebook_error": None,
}

op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
Expand Down Expand Up @@ -1630,6 +1634,7 @@ def test_execute_complete_failure_and_repair_run(
"run_page_url": RUN_PAGE_URL,
"run_state": run_state_failed.to_json(),
"repair_run": True,
"notebook_error": None,
}

op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
Expand Down
72 changes: 70 additions & 2 deletions tests/providers/databricks/triggers/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,17 @@
RETRY_DELAY = 10
RETRY_LIMIT = 3
RUN_ID = 1
TASK_RUN_ID = 11
JOB_ID = 42
RUN_PAGE_URL = "https://XX.cloud.databricks.com/#jobs/1/runs/1"
ERROR_MESSAGE = "error message from databricks API"
GET_RUN_OUTPUT_RESPONSE = {"metadata": {}, "error": ERROR_MESSAGE, "notebook_output": {}}

RUN_LIFE_CYCLE_STATES = ["PENDING", "RUNNING", "TERMINATING", "TERMINATED", "SKIPPED", "INTERNAL_ERROR"]

LIFE_CYCLE_STATE_PENDING = "PENDING"
LIFE_CYCLE_STATE_TERMINATED = "TERMINATED"
LIFE_CYCLE_STATE_INTERNAL_ERROR = "INTERNAL_ERROR"

STATE_MESSAGE = "Waiting for cluster"

Expand All @@ -66,6 +70,25 @@
"result_state": "SUCCESS",
},
}
GET_RUN_RESPONSE_TERMINATED_WITH_FAILED = {
"job_id": JOB_ID,
"run_page_url": RUN_PAGE_URL,
"state": {
"life_cycle_state": LIFE_CYCLE_STATE_INTERNAL_ERROR,
"state_message": None,
"result_state": "FAILED",
},
"tasks": [
{
"run_id": TASK_RUN_ID,
"state": {
"life_cycle_state": "TERMINATED",
"result_state": "FAILED",
"state_message": "Workload failed, see run output for details",
},
}
],
}


class TestDatabricksExecutionTrigger:
Expand Down Expand Up @@ -101,15 +124,21 @@ def test_serialize(self):
)

@pytest.mark.asyncio
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_output")
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run")
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_page_url")
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state")
async def test_run_return_success(self, mock_get_run_state, mock_get_run_page_url):
async def test_run_return_success(
self, mock_get_run_state, mock_get_run_page_url, mock_get_run, mock_get_run_output
):
mock_get_run_page_url.return_value = RUN_PAGE_URL
mock_get_run_state.return_value = RunState(
life_cycle_state=LIFE_CYCLE_STATE_TERMINATED,
state_message="",
result_state="SUCCESS",
)
mock_get_run.return_value = GET_RUN_RESPONSE_TERMINATED
mock_get_run_output.return_value = GET_RUN_OUTPUT_RESPONSE

trigger_event = self.trigger.run()
async for event in trigger_event:
Expand All @@ -121,13 +150,49 @@ async def test_run_return_success(self, mock_get_run_state, mock_get_run_page_ur
).to_json(),
"run_page_url": RUN_PAGE_URL,
"repair_run": False,
"notebook_error": None,
}
)

@pytest.mark.asyncio
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_output")
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run")
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_page_url")
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state")
async def test_run_return_failure(
self, mock_get_run_state, mock_get_run_page_url, mock_get_run, mock_get_run_output
):
mock_get_run_page_url.return_value = RUN_PAGE_URL
mock_get_run_state.return_value = RunState(
life_cycle_state=LIFE_CYCLE_STATE_TERMINATED,
state_message="",
result_state="FAILED",
)
mock_get_run_output.return_value = GET_RUN_OUTPUT_RESPONSE
mock_get_run.return_value = GET_RUN_RESPONSE_TERMINATED_WITH_FAILED

trigger_event = self.trigger.run()
async for event in trigger_event:
assert event == TriggerEvent(
{
"run_id": RUN_ID,
"run_state": RunState(
life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message="", result_state="FAILED"
).to_json(),
"run_page_url": RUN_PAGE_URL,
"repair_run": False,
"notebook_error": ERROR_MESSAGE,
}
)

@pytest.mark.asyncio
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_output")
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run")
@mock.patch("airflow.providers.databricks.triggers.databricks.asyncio.sleep")
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state")
async def test_sleep_between_retries(self, mock_get_run_state, mock_sleep):
async def test_sleep_between_retries(
self, mock_get_run_state, mock_sleep, mock_get_run, mock_get_run_output
):
mock_get_run_state.side_effect = [
RunState(
life_cycle_state=LIFE_CYCLE_STATE_PENDING,
Expand All @@ -140,6 +205,8 @@ async def test_sleep_between_retries(self, mock_get_run_state, mock_sleep):
result_state="SUCCESS",
),
]
mock_get_run.return_value = GET_RUN_RESPONSE_TERMINATED
mock_get_run_output.return_value = GET_RUN_OUTPUT_RESPONSE

trigger_event = self.trigger.run()
async for event in trigger_event:
Expand All @@ -151,6 +218,7 @@ async def test_sleep_between_retries(self, mock_get_run_state, mock_sleep):
).to_json(),
"run_page_url": RUN_PAGE_URL,
"repair_run": False,
"notebook_error": None,
}
)
mock_sleep.assert_called_once()
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/databricks/utils/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_validate_trigger_event_success(self):
"run_id": RUN_ID,
"run_page_url": RUN_PAGE_URL,
"run_state": RunState("TERMINATED", "SUCCESS", "").to_json(),
"notebook_error": None
"notebook_error": None,
}
assert validate_trigger_event(event) is None

Expand Down

0 comments on commit f406db4

Please sign in to comment.