Skip to content

Commit

Permalink
Fix BigQueryInsertJobOperator not exiting deferred state (#31591)
Browse files Browse the repository at this point in the history
  • Loading branch information
pankajastro authored Jul 29, 2023
1 parent fcbbf47 commit 81b85eb
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 88 deletions.
31 changes: 9 additions & 22 deletions airflow/providers/google/cloud/hooks/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -3103,29 +3103,16 @@ async def get_job_instance(
with await self.service_file_as_context() as f:
return Job(job_id=job_id, project=project_id, service_file=f, session=cast(Session, session))

async def get_job_status(
self,
job_id: str | None,
project_id: str | None = None,
) -> str | None:
"""Poll for job status asynchronously using gcloud-aio.
Note that an OSError is raised when Job results are still pending.
Exception means that Job finished with errors
"""
async def get_job_status(self, job_id: str | None, project_id: str | None = None) -> str:
async with ClientSession() as s:
try:
self.log.info("Executing get_job_status...")
job_client = await self.get_job_instance(project_id, job_id, s)
job_status_response = await job_client.result(cast(Session, s))
if job_status_response:
job_status = "success"
except OSError:
job_status = "pending"
except Exception as e:
self.log.info("Query execution finished with errors...")
job_status = str(e)
return job_status
job_client = await self.get_job_instance(project_id, job_id, s)
job = await job_client.get_job()
status = job.get("status", {})
if status["state"] == "DONE":
if "errorResult" in status:
return "error"
return "success"
return status["state"].lower()

async def get_job_output(
self,
Expand Down
63 changes: 31 additions & 32 deletions airflow/providers/google/cloud/triggers/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,31 +72,29 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
)

async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"""Gets current job execution status and yields a TriggerEvent."""
"""Gets current job execution status and yields a TriggerEvent."""
hook = self._get_async_hook()
while True:
try:
# Poll for job execution status
response_from_hook = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id)
self.log.debug("Response from hook: %s", response_from_hook)

if response_from_hook == "success":
job_status = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id)
if job_status == "success":
yield TriggerEvent(
{
"job_id": self.job_id,
"status": "success",
"status": job_status,
"message": "Job completed",
}
)
return
elif response_from_hook == "pending":
self.log.info("Query is still running...")
self.log.info("Sleeping for %s seconds.", self.poll_interval)
await asyncio.sleep(self.poll_interval)
else:
yield TriggerEvent({"status": "error", "message": response_from_hook})
elif job_status == "error":
yield TriggerEvent({"status": "error"})
return

else:
self.log.info(
"Bigquery job status is %s. Sleeping for %s seconds.", job_status, self.poll_interval
)
await asyncio.sleep(self.poll_interval)
except Exception as e:
self.log.exception("Exception occurred while checking for query completion")
yield TriggerEvent({"status": "error", "message": str(e)})
Expand Down Expand Up @@ -129,8 +127,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
while True:
try:
# Poll for job execution status
response_from_hook = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id)
if response_from_hook == "success":
job_status = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id)
if job_status == "success":
query_results = await hook.get_job_output(job_id=self.job_id, project_id=self.project_id)

records = hook.get_records(query_results)
Expand All @@ -154,14 +152,14 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
}
)
return

elif response_from_hook == "pending":
self.log.info("Query is still running...")
self.log.info("Sleeping for %s seconds.", self.poll_interval)
await asyncio.sleep(self.poll_interval)
else:
yield TriggerEvent({"status": "error", "message": response_from_hook})
elif job_status == "error":
yield TriggerEvent({"status": "error", "message": job_status})
return
else:
self.log.info(
"Bigquery job status is %s. Sleeping for %s seconds.", job_status, self.poll_interval
)
await asyncio.sleep(self.poll_interval)
except Exception as e:
self.log.exception("Exception occurred while checking for query completion")
yield TriggerEvent({"status": "error", "message": str(e)})
Expand Down Expand Up @@ -201,26 +199,27 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
while True:
try:
# Poll for job execution status
response_from_hook = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id)
if response_from_hook == "success":
job_status = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id)
if job_status == "success":
query_results = await hook.get_job_output(job_id=self.job_id, project_id=self.project_id)
records = hook.get_records(query_results=query_results, as_dict=self.as_dict)
self.log.debug("Response from hook: %s", response_from_hook)
self.log.debug("Response from hook: %s", job_status)
yield TriggerEvent(
{
"status": "success",
"message": response_from_hook,
"message": job_status,
"records": records,
}
)
return
elif response_from_hook == "pending":
self.log.info("Query is still running...")
self.log.info("Sleeping for %s seconds.", self.poll_interval)
await asyncio.sleep(self.poll_interval)
else:
yield TriggerEvent({"status": "error", "message": response_from_hook})
elif job_status == "error":
yield TriggerEvent({"status": "error"})
return
else:
self.log.info(
"Bigquery job status is %s. Sleeping for %s seconds.", job_status, self.poll_interval
)
await asyncio.sleep(self.poll_interval)
except Exception as e:
self.log.exception("Exception occurred while checking for query completion")
yield TriggerEvent({"status": "error", "message": str(e)})
Expand Down
33 changes: 11 additions & 22 deletions tests/providers/google/cloud/hooks/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2139,34 +2139,23 @@ async def test_get_job_instance(self, mock_session):
result = await hook.get_job_instance(project_id=PROJECT_ID, job_id=JOB_ID, session=mock_session)
assert isinstance(result, Job)

@pytest.mark.parametrize(
"job_status, expected",
[
({"status": {"state": "DONE"}}, "success"),
({"status": {"state": "DONE", "errorResult": "Timeout"}}, "error"),
({"status": {"state": "running"}}, "running"),
],
)
@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
async def test_get_job_status_success(self, mock_job_instance):
async def test_get_job_status(self, mock_job_instance, job_status, expected):
hook = BigQueryAsyncHook()
mock_job_client = AsyncMock(Job)
mock_job_instance.return_value = mock_job_client
response = "success"
mock_job_instance.return_value.result.return_value = response
mock_job_instance.return_value.get_job.return_value = job_status
resp = await hook.get_job_status(job_id=JOB_ID, project_id=PROJECT_ID)
assert resp == response

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
async def test_get_job_status_oserror(self, mock_job_instance):
"""Asserts that the BigQueryAsyncHook returns a pending response when OSError is raised"""
mock_job_instance.return_value.result.side_effect = OSError()
hook = BigQueryAsyncHook()
job_status = await hook.get_job_status(job_id=JOB_ID, project_id=PROJECT_ID)
assert job_status == "pending"

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
async def test_get_job_status_exception(self, mock_job_instance, caplog):
"""Asserts that the logging is done correctly when BigQueryAsyncHook raises Exception"""
mock_job_instance.return_value.result.side_effect = Exception()
hook = BigQueryAsyncHook()
await hook.get_job_status(job_id=JOB_ID, project_id=PROJECT_ID)
assert "Query execution finished with errors..." in caplog.text
assert resp == expected

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
Expand Down
20 changes: 8 additions & 12 deletions tests/providers/google/cloud/triggers/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ async def test_bigquery_insert_job_trigger_running(self, mock_job_instance, capl

mock_job_client = AsyncMock(Job)
mock_job_instance.return_value = mock_job_client
mock_job_instance.return_value.result.side_effect = OSError
mock_job_instance.return_value.get_job.return_value = {"status": {"state": "running"}}
caplog.set_level(logging.INFO)

task = asyncio.create_task(insert_job_trigger.run().__anext__())
Expand All @@ -189,8 +189,7 @@ async def test_bigquery_insert_job_trigger_running(self, mock_job_instance, capl
# TriggerEvent was not returned
assert task.done() is False

assert "Query is still running..." in caplog.text
assert f"Sleeping for {POLLING_PERIOD_SECONDS} seconds." in caplog.text
assert "Bigquery job status is running. Sleeping for 4.0 seconds." in caplog.text

# Prevents error when task is destroyed while in "pending" state
asyncio.get_event_loop().stop()
Expand All @@ -205,7 +204,7 @@ async def test_bigquery_op_trigger_terminated(self, mock_job_status, caplog, ins

generator = insert_job_trigger.run()
actual = await generator.asend(None)
assert TriggerEvent({"status": "error", "message": "error"}) == actual
assert TriggerEvent({"status": "error"}) == actual

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
Expand Down Expand Up @@ -241,7 +240,7 @@ async def test_bigquery_get_data_trigger_running(self, mock_job_instance, caplog

mock_job_client = AsyncMock(Job)
mock_job_instance.return_value = mock_job_client
mock_job_instance.return_value.result.side_effect = OSError
mock_job_instance.return_value.get_job.return_value = {"status": {"state": "RUNNING"}}
caplog.set_level(logging.INFO)

task = asyncio.create_task(get_data_trigger.run().__anext__())
Expand All @@ -250,8 +249,7 @@ async def test_bigquery_get_data_trigger_running(self, mock_job_instance, caplog
# TriggerEvent was not returned
assert task.done() is False

assert "Query is still running..." in caplog.text
assert f"Sleeping for {POLLING_PERIOD_SECONDS} seconds." in caplog.text
assert "Bigquery job status is running. Sleeping for 4.0 seconds." in caplog.text

# Prevents error when task is destroyed while in "pending" state
asyncio.get_event_loop().stop()
Expand All @@ -266,7 +264,7 @@ async def test_bigquery_get_data_trigger_terminated(self, mock_job_status, caplo

generator = get_data_trigger.run()
actual = await generator.asend(None)
assert TriggerEvent({"status": "error", "message": "error"}) == actual
assert TriggerEvent({"status": "error"}) == actual

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
Expand Down Expand Up @@ -336,17 +334,15 @@ async def test_bigquery_check_trigger_running(self, mock_job_instance, caplog, c

mock_job_client = AsyncMock(Job)
mock_job_instance.return_value = mock_job_client
mock_job_instance.return_value.result.side_effect = OSError
caplog.set_level(logging.INFO)
mock_job_instance.return_value.get_job.return_value = {"status": {"state": "running"}}

task = asyncio.create_task(check_trigger.run().__anext__())
await asyncio.sleep(0.5)

# TriggerEvent was not returned
assert task.done() is False

assert "Query is still running..." in caplog.text
assert f"Sleeping for {POLLING_PERIOD_SECONDS} seconds." in caplog.text
assert "Bigquery job status is running. Sleeping for 4.0 seconds." in caplog.text

# Prevents error when task is destroyed while in "pending" state
asyncio.get_event_loop().stop()
Expand Down

0 comments on commit 81b85eb

Please sign in to comment.