From e3b3ce864b188943fab10ff565167eeda118ead6 Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Wed, 24 Apr 2024 18:38:20 +0545 Subject: [PATCH 1/5] Fix DataprocSubmitJobOperator in deferrable mode=True when task is marked as failed. --- .../google/cloud/operators/dataproc.py | 1 + .../google/cloud/triggers/dataproc.py | 30 ++++++--- .../google/cloud/triggers/test_dataproc.py | 64 ++++++++++++++++++- 3 files changed, 84 insertions(+), 11 deletions(-) diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py index edbfbd3f39b45..9af81247dfa67 100644 --- a/airflow/providers/google/cloud/operators/dataproc.py +++ b/airflow/providers/google/cloud/operators/dataproc.py @@ -2591,6 +2591,7 @@ def execute(self, context: Context): gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, polling_interval_seconds=self.polling_interval_seconds, + cancel_on_kill=self.cancel_on_kill, ), method_name="execute_complete", ) diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py index f0aecddb4a8ed..c767d1d73ab3d 100644 --- a/airflow/providers/google/cloud/triggers/dataproc.py +++ b/airflow/providers/google/cloud/triggers/dataproc.py @@ -43,6 +43,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, polling_interval_seconds: int = 30, + cancel_on_kill: bool = True, ): super().__init__() self.region = region @@ -50,6 +51,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain self.polling_interval_seconds = polling_interval_seconds + self.cancel_on_kill = cancel_on_kill def get_async_hook(self): return DataprocAsyncHook( @@ -91,20 +93,28 @@ def serialize(self): "gcp_conn_id": self.gcp_conn_id, "impersonation_chain": self.impersonation_chain, "polling_interval_seconds": self.polling_interval_seconds, + "cancel_on_kill": self.cancel_on_kill, }, ) async def run(self): - while True: - job = await self.get_async_hook().get_job( - project_id=self.project_id, region=self.region, job_id=self.job_id - ) - state = job.status.state - self.log.info("Dataproc job: %s is in state: %s", self.job_id, state) - if state in (JobStatus.State.DONE, JobStatus.State.CANCELLED, JobStatus.State.ERROR): - break - await asyncio.sleep(self.polling_interval_seconds) - yield TriggerEvent({"job_id": self.job_id, "job_state": state, "job": job}) + try: + while True: + job = await self.get_async_hook().get_job( + project_id=self.project_id, region=self.region, job_id=self.job_id + ) + state = job.status.state + self.log.info("Dataproc job: %s is in state: %s", self.job_id, state) + if state in (JobStatus.State.DONE, JobStatus.State.CANCELLED, JobStatus.State.ERROR): + break + await asyncio.sleep(self.polling_interval_seconds) + yield TriggerEvent({"job_id": self.job_id, "job_state": state, "job": job}) + except asyncio.CancelledError: + self.log.info("Task got cancelled.") + if self.job_id and self.cancel_on_kill: + await self.get_async_hook().cancel_job( + job_id=self.job_id, project_id=self.project_id, region=self.region + ) class DataprocClusterTrigger(DataprocBaseTrigger): diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py index 45607d51b8a59..3f584c8be7484 100644 --- a/tests/providers/google/cloud/triggers/test_dataproc.py +++ b/tests/providers/google/cloud/triggers/test_dataproc.py @@ -22,7 +22,7 @@ from unittest import mock import pytest -from google.cloud.dataproc_v1 import Batch, ClusterStatus +from google.cloud.dataproc_v1 import Batch, ClusterStatus, JobStatus from google.protobuf.any_pb2 import Any from google.rpc.status_pb2 import Status @@ -30,6 +30,7 @@ DataprocBatchTrigger, DataprocClusterTrigger, DataprocOperationTrigger, + DataprocSubmitTrigger, ) from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType from airflow.triggers.base import TriggerEvent @@ -47,6 +48,7 @@ TEST_POLL_INTERVAL = 5 TEST_GCP_CONN_ID = "google_cloud_default" TEST_OPERATION_NAME = "name" +TEST_JOB_ID = "test-job-id" @pytest.fixture @@ -111,6 +113,17 @@ def func(**kwargs): return func +@pytest.fixture +def submit_trigger(): + return DataprocSubmitTrigger( + job_id=TEST_JOB_ID, + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + gcp_conn_id=TEST_GCP_CONN_ID, + polling_interval_seconds=TEST_POLL_INTERVAL, + ) + + @pytest.fixture def async_get_batch(): def func(**kwargs): @@ -375,3 +388,52 @@ async def test_async_operation_triggers_on_error(self, mock_hook, operation_trig ) actual_event = await operation_trigger.run().asend(None) assert expected_event == actual_event + + +@pytest.mark.db_test +class TestDataprocSubmitTrigger: + def test_submit_trigger_serialization(self, submit_trigger): + """Test that the trigger serializes its configuration correctly.""" + classpath, kwargs = submit_trigger.serialize() + assert classpath == "airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger" + assert kwargs == { + "job_id": TEST_JOB_ID, + "project_id": TEST_PROJECT_ID, + "region": TEST_REGION, + "gcp_conn_id": TEST_GCP_CONN_ID, + "polling_interval_seconds": TEST_POLL_INTERVAL, + "cancel_on_kill": True, + "impersonation_chain": None, + } + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_async_hook") + async def test_submit_trigger_run_success(self, mock_get_async_hook, submit_trigger): + """Test the trigger correctly handles a job completion.""" + mock_hook = mock_get_async_hook.return_value + mock_hook.get_job = mock.AsyncMock( + return_value=mock.AsyncMock(status=mock.AsyncMock(state=JobStatus.State.DONE)) + ) + + async_gen = submit_trigger.run() + event = await async_gen.asend(None) + expected_event = TriggerEvent( + {"job_id": TEST_JOB_ID, "job_state": JobStatus.State.DONE, "job": mock_hook.get_job.return_value} + ) + assert event.payload == expected_event.payload + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_async_hook") + async def test_submit_trigger_run_error(self, mock_get_async_hook, submit_trigger): + """Test the trigger correctly handles a job error.""" + mock_hook = mock_get_async_hook.return_value + mock_hook.get_job = mock.AsyncMock( + return_value=mock.AsyncMock(status=mock.AsyncMock(state=JobStatus.State.ERROR)) + ) + + async_gen = submit_trigger.run() + event = await async_gen.asend(None) + expected_event = TriggerEvent( + {"job_id": TEST_JOB_ID, "job_state": JobStatus.State.ERROR, "job": mock_hook.get_job.return_value} + ) + assert event.payload == expected_event.payload From 455886a2caba4683b1e980d487f2edf2caba7cb6 Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Wed, 24 Apr 2024 21:20:23 +0545 Subject: [PATCH 2/5] Add the test for CancelledError --- .../google/cloud/triggers/test_dataproc.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py index 3f584c8be7484..bd3c69be59f9d 100644 --- a/tests/providers/google/cloud/triggers/test_dataproc.py +++ b/tests/providers/google/cloud/triggers/test_dataproc.py @@ -437,3 +437,41 @@ async def test_submit_trigger_run_error(self, mock_get_async_hook, submit_trigge {"job_id": TEST_JOB_ID, "job_state": JobStatus.State.ERROR, "job": mock_hook.get_job.return_value} ) assert event.payload == expected_event.payload + + import asyncio + from unittest.mock import AsyncMock, patch + + import pytest + + @pytest.mark.asyncio + @patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_async_hook") + async def test_submit_trigger_run_cancelled(self, mock_get_async_hook, submit_trigger): + """Test the trigger correctly handles an asyncio.CancelledError.""" + mock_hook = mock_get_async_hook.return_value + mock_hook.get_job = mock.AsyncMock(side_effect=asyncio.CancelledError) + mock_hook.cancel_job = mock.AsyncMock() + + async_gen = submit_trigger.run() + + try: + await async_gen.__anext__() + # If no error is raised, assert a failure as we expected an exception + pytest.fail("Expected an asyncio.CancelledError but didn't get one.") + except asyncio.CancelledError: + assert True, "asyncio.CancelledError was caught as expected." + except StopAsyncIteration: + # If StopAsyncIteration is raised, check if the CancelledError was handled internally + # Since we caught StopAsyncIteration, it means the async generator concluded without issue + # Verify if cancel_job was called if cancel_on_kill is True + if submit_trigger.cancel_on_kill: + mock_hook.cancel_job.assert_called_once_with( + job_id=submit_trigger.job_id, + project_id=submit_trigger.project_id, + region=submit_trigger.region, + ) + else: + mock_hook.cancel_job.assert_not_called() + assert True, "Cancellation was handled internally and properly." + finally: + # Clean up generator + await async_gen.aclose() From 395e8dbe45f7dd3a0445f35f74693cd8c0ada5b0 Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Thu, 25 Apr 2024 15:49:08 +0545 Subject: [PATCH 3/5] Use sync hook to cancel job and raise error --- .../google/cloud/triggers/dataproc.py | 36 ++++++++++++++++--- .../google/cloud/triggers/test_dataproc.py | 7 +--- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py index c767d1d73ab3d..d04a49d16db0a 100644 --- a/airflow/providers/google/cloud/triggers/dataproc.py +++ b/airflow/providers/google/cloud/triggers/dataproc.py @@ -27,7 +27,8 @@ from google.api_core.exceptions import NotFound from google.cloud.dataproc_v1 import Batch, ClusterStatus, JobStatus -from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -59,6 +60,16 @@ def get_async_hook(self): impersonation_chain=self.impersonation_chain, ) + def get_sync_hook(self): + # The synchronous hook is utilized to delete the cluster when a task is cancelled. + # This is because the asynchronous hook deletion is not awaited when the trigger task + # is cancelled. The call for deleting the cluster or job through the sync hook is not a blocking + # call, which means it does not wait until the cluster or job is deleted. + return DataprocHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + class DataprocSubmitTrigger(DataprocBaseTrigger): """ @@ -111,10 +122,25 @@ async def run(self): yield TriggerEvent({"job_id": self.job_id, "job_state": state, "job": job}) except asyncio.CancelledError: self.log.info("Task got cancelled.") - if self.job_id and self.cancel_on_kill: - await self.get_async_hook().cancel_job( - job_id=self.job_id, project_id=self.project_id, region=self.region - ) + try: + if self.job_id and self.cancel_on_kill: + self.log.info("Cancelling the job: %s", self.job_id) + # The synchronous hook is utilized to delete the cluster when a task is cancelled. This + # is because the asynchronous hook deletion is not awaited when the trigger task is + # cancelled. The call for deleting the cluster or job through the sync hook is not a + # blocking call, which means it does not wait until the cluster or job is deleted. + self.get_sync_hook().cancel_job( + job_id=self.job_id, project_id=self.project_id, region=self.region + ) + self.log.info("Job: %s is cancelled", self.job_id) + yield TriggerEvent( + {"job_id": self.job_id, "job_state": ClusterStatus.State.DELETING, "job": job} + ) + except Exception as e: + self.log.error("Failed to cancel the job: %s with error : %s", self.job_id, str(e)) + raise AirflowException( + f"Failed to cancel the job: {self.job_id} with error : {str(e)}" + ) from e class DataprocClusterTrigger(DataprocBaseTrigger): diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py index bd3c69be59f9d..cbcc2827a9725 100644 --- a/tests/providers/google/cloud/triggers/test_dataproc.py +++ b/tests/providers/google/cloud/triggers/test_dataproc.py @@ -438,13 +438,8 @@ async def test_submit_trigger_run_error(self, mock_get_async_hook, submit_trigge ) assert event.payload == expected_event.payload - import asyncio - from unittest.mock import AsyncMock, patch - - import pytest - @pytest.mark.asyncio - @patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_async_hook") + @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_async_hook") async def test_submit_trigger_run_cancelled(self, mock_get_async_hook, submit_trigger): """Test the trigger correctly handles an asyncio.CancelledError.""" mock_hook = mock_get_async_hook.return_value From fc6f54d36eb5b868adb8f601b75bcca4a1c7f775 Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Thu, 25 Apr 2024 16:25:48 +0545 Subject: [PATCH 4/5] Fix the test --- .../google/cloud/triggers/dataproc.py | 4 +- .../google/cloud/triggers/test_dataproc.py | 55 +++++++++++-------- 2 files changed, 33 insertions(+), 26 deletions(-) diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py index d04a49d16db0a..6f54eaecc3ee4 100644 --- a/airflow/providers/google/cloud/triggers/dataproc.py +++ b/airflow/providers/google/cloud/triggers/dataproc.py @@ -133,9 +133,7 @@ async def run(self): job_id=self.job_id, project_id=self.project_id, region=self.region ) self.log.info("Job: %s is cancelled", self.job_id) - yield TriggerEvent( - {"job_id": self.job_id, "job_state": ClusterStatus.State.DELETING, "job": job} - ) + yield TriggerEvent({"job_id": self.job_id, "job_state": ClusterStatus.State.DELETING}) except Exception as e: self.log.error("Failed to cancel the job: %s with error : %s", self.job_id, str(e)) raise AirflowException( diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py index cbcc2827a9725..f2514efb18899 100644 --- a/tests/providers/google/cloud/triggers/test_dataproc.py +++ b/tests/providers/google/cloud/triggers/test_dataproc.py @@ -440,33 +440,42 @@ async def test_submit_trigger_run_error(self, mock_get_async_hook, submit_trigge @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_async_hook") - async def test_submit_trigger_run_cancelled(self, mock_get_async_hook, submit_trigger): + @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_sync_hook") + async def test_submit_trigger_run_cancelled( + self, mock_get_sync_hook, mock_get_async_hook, submit_trigger + ): """Test the trigger correctly handles an asyncio.CancelledError.""" - mock_hook = mock_get_async_hook.return_value - mock_hook.get_job = mock.AsyncMock(side_effect=asyncio.CancelledError) - mock_hook.cancel_job = mock.AsyncMock() + mock_async_hook = mock_get_async_hook.return_value + mock_async_hook.get_job.side_effect = asyncio.CancelledError + + mock_sync_hook = mock_get_sync_hook.return_value + mock_sync_hook.cancel_job = mock.MagicMock() async_gen = submit_trigger.run() try: - await async_gen.__anext__() - # If no error is raised, assert a failure as we expected an exception - pytest.fail("Expected an asyncio.CancelledError but didn't get one.") + await async_gen.asend(None) + # Should raise StopAsyncIteration if no more items to yield + await async_gen.asend(None) except asyncio.CancelledError: - assert True, "asyncio.CancelledError was caught as expected." + # Handle the cancellation as expected + pass except StopAsyncIteration: - # If StopAsyncIteration is raised, check if the CancelledError was handled internally - # Since we caught StopAsyncIteration, it means the async generator concluded without issue - # Verify if cancel_job was called if cancel_on_kill is True - if submit_trigger.cancel_on_kill: - mock_hook.cancel_job.assert_called_once_with( - job_id=submit_trigger.job_id, - project_id=submit_trigger.project_id, - region=submit_trigger.region, - ) - else: - mock_hook.cancel_job.assert_not_called() - assert True, "Cancellation was handled internally and properly." - finally: - # Clean up generator - await async_gen.aclose() + # The generator should be properly closed after handling the cancellation + pass + except Exception as e: + # Catch any other exceptions that should not occur + pytest.fail(f"Unexpected exception raised: {e}") + + # Check if cancel_job was correctly called + if submit_trigger.cancel_on_kill: + mock_sync_hook.cancel_job.assert_called_once_with( + job_id=submit_trigger.job_id, + project_id=submit_trigger.project_id, + region=submit_trigger.region, + ) + else: + mock_sync_hook.cancel_job.assert_not_called() + + # Clean up the generator + await async_gen.aclose() From 4827665df9182810f3594ca37d6a0ac4364a0676 Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Fri, 26 Apr 2024 17:52:50 +0545 Subject: [PATCH 5/5] Fix PR comment --- airflow/providers/google/cloud/triggers/dataproc.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py index 245ea7fd75a16..427bf8a09615c 100644 --- a/airflow/providers/google/cloud/triggers/dataproc.py +++ b/airflow/providers/google/cloud/triggers/dataproc.py @@ -138,9 +138,7 @@ async def run(self): yield TriggerEvent({"job_id": self.job_id, "job_state": ClusterStatus.State.DELETING}) except Exception as e: self.log.error("Failed to cancel the job: %s with error : %s", self.job_id, str(e)) - raise AirflowException( - f"Failed to cancel the job: {self.job_id} with error : {str(e)}" - ) from e + raise e class DataprocClusterTrigger(DataprocBaseTrigger):