diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index cf903b8ddc65f..a0fabb324d0ba 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -27,6 +27,7 @@ from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri +from airflow.providers.amazon.aws.triggers.emr import EmrAddStepsTrigger from airflow.providers.amazon.aws.utils.waiter import waiter from airflow.utils.helpers import exactly_one, prune_dict from airflow.utils.types import NOTSET, ArgNotSet @@ -55,6 +56,10 @@ class EmrAddStepsOperator(BaseOperator): :param wait_for_completion: If True, the operator will wait for all the steps to be completed. :param execution_role_arn: The ARN of the runtime role for a step on the cluster. :param do_xcom_push: if True, job_flow_id is pushed to XCom with key job_flow_id. + :param wait_for_completion: Whether to wait for job run completion. (default: True) + :param deferrable: If True, the operator will wait asynchronously for the job to complete. + This implies waiting for completion. This mode requires aiobotocore module to be installed. 
+ (default: False) """ template_fields: Sequence[str] = ( @@ -84,6 +89,7 @@ def __init__( waiter_delay: int | None = None, waiter_max_attempts: int | None = None, execution_role_arn: str | None = None, + deferrable: bool = False, **kwargs, ): if not exactly_one(job_flow_id is None, job_flow_name is None): @@ -96,10 +102,11 @@ def __init__( self.job_flow_name = job_flow_name self.cluster_states = cluster_states self.steps = steps - self.wait_for_completion = wait_for_completion + self.wait_for_completion = False if deferrable else wait_for_completion self.waiter_delay = waiter_delay self.waiter_max_attempts = waiter_max_attempts self.execution_role_arn = execution_role_arn + self.deferrable = deferrable def execute(self, context: Context) -> list[str]: emr_hook = EmrHook(aws_conn_id=self.aws_conn_id) @@ -137,7 +144,7 @@ def execute(self, context: Context) -> list[str]: steps = self.steps if isinstance(steps, str): steps = ast.literal_eval(steps) - return emr_hook.add_job_flow_steps( + step_ids = emr_hook.add_job_flow_steps( job_flow_id=job_flow_id, steps=steps, wait_for_completion=self.wait_for_completion, @@ -145,6 +152,26 @@ def execute(self, context: Context) -> list[str]: waiter_max_attempts=self.waiter_max_attempts, execution_role_arn=self.execution_role_arn, ) + if self.deferrable: + self.defer( + trigger=EmrAddStepsTrigger( + job_flow_id=job_flow_id, + step_ids=step_ids, + aws_conn_id=self.aws_conn_id, + max_attempts=self.waiter_max_attempts, + poll_interval=self.waiter_delay, + ), + method_name="execute_complete", + ) + + return step_ids + + def execute_complete(self, context, event=None): + if event["status"] != "success": + raise AirflowException(f"Error resuming cluster: {event}") + else: + self.log.info("Steps completed successfully") + return event["step_ids"] class EmrStartNotebookExecutionOperator(BaseOperator): diff --git a/airflow/providers/amazon/aws/triggers/emr.py b/airflow/providers/amazon/aws/triggers/emr.py new file mode 100644 index 
0000000000000..2afc1c45afbdf --- /dev/null +++ b/airflow/providers/amazon/aws/triggers/emr.py @@ -0,0 +1,99 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +from typing import Any + +from botocore.exceptions import WaiterError + +from airflow.providers.amazon.aws.hooks.emr import EmrHook +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class EmrAddStepsTrigger(BaseTrigger): + """ + AWS Emr Add Steps Trigger + The trigger will asynchronously poll the boto3 API and wait for the + steps to finish executing. + :param job_flow_id: The id of the job flow. + :param step_ids: The id of the steps being waited upon. + :param poll_interval: The amount of time in seconds to wait between attempts. + :param max_attempts: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. 
+ """ + + def __init__( + self, + job_flow_id: str, + step_ids: list[str], + aws_conn_id: str, + max_attempts: int | None, + poll_interval: int | None, + ): + self.job_flow_id = job_flow_id + self.step_ids = step_ids + self.aws_conn_id = aws_conn_id + self.max_attempts = max_attempts + self.poll_interval = poll_interval + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "airflow.providers.amazon.aws.triggers.emr.EmrAddStepsTrigger", + { + "job_flow_id": str(self.job_flow_id), + "step_ids": self.step_ids, + "poll_interval": str(self.poll_interval), + "max_attempts": str(self.max_attempts), + "aws_conn_id": str(self.aws_conn_id), + }, + ) + + async def run(self): + self.hook = EmrHook(aws_conn_id=self.aws_conn_id) + async with self.hook.async_conn as client: + for step_id in self.step_ids: + attempt = 0 + waiter = client.get_waiter("step_complete") + while attempt < int(self.max_attempts): + attempt += 1 + try: + await waiter.wait( + ClusterId=self.job_flow_id, + StepId=step_id, + WaiterConfig={ + "Delay": int(self.poll_interval), + "MaxAttempts": 1, + }, + ) + break + except WaiterError as error: + if "terminal failure" in str(error): + yield TriggerEvent( + {"status": "failure", "message": f"Step {step_id} failed: {error}"} + ) + break + self.log.info( + "Status of step is %s - %s", + error.last_response["Step"]["Status"]["State"], + error.last_response["Step"]["Status"]["StateChangeReason"], + ) + await asyncio.sleep(int(self.poll_interval)) + if attempt >= int(self.max_attempts): + yield TriggerEvent({"status": "failure", "message": "Steps failed: max attempts reached"}) + else: + yield TriggerEvent({"status": "success", "message": "Steps completed", "step_ids": self.step_ids}) diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 716e3d8432cfb..9725a55bbaa15 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -522,6 +522,9 @@ triggers: python-modules: - 
airflow.providers.amazon.aws.triggers.glue - airflow.providers.amazon.aws.triggers.glue_crawler + - integration-name: Amazon EMR + python-modules: + - airflow.providers.amazon.aws.triggers.emr transfers: - source-integration-name: Amazon DynamoDB diff --git a/docs/apache-airflow-providers-amazon/operators/emr/emr.rst b/docs/apache-airflow-providers-amazon/operators/emr/emr.rst index bd4e3e78c3a18..6cd628eb239cd 100644 --- a/docs/apache-airflow-providers-amazon/operators/emr/emr.rst +++ b/docs/apache-airflow-providers-amazon/operators/emr/emr.rst @@ -89,6 +89,10 @@ Add Steps to an EMR job flow To add steps to an existing EMR Job flow you can use :class:`~airflow.providers.amazon.aws.operators.emr.EmrAddStepsOperator`. +This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. +Using ``deferrable`` mode will release worker slots and lead to more efficient utilization of +resources within the Airflow cluster. However, this mode will need the Airflow triggerer to be +available in your deployment. ..
exampleinclude:: /../../tests/system/providers/amazon/aws/example_emr.py :language: python diff --git a/tests/providers/amazon/aws/operators/test_emr_add_steps.py b/tests/providers/amazon/aws/operators/test_emr_add_steps.py index 5171815eb37b0..0b279c051f9c6 100644 --- a/tests/providers/amazon/aws/operators/test_emr_add_steps.py +++ b/tests/providers/amazon/aws/operators/test_emr_add_steps.py @@ -25,11 +25,12 @@ import pytest from jinja2 import StrictUndefined -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import DAG, DagRun, TaskInstance from airflow.providers.amazon.aws.hooks.emr import EmrHook from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.aws.operators.emr import EmrAddStepsOperator +from airflow.providers.amazon.aws.triggers.emr import EmrAddStepsTrigger from airflow.utils import timezone from tests.test_utils import AIRFLOW_MAIN_FOLDER @@ -244,3 +245,36 @@ def test_wait_for_completion(self, mock_add_job_flow_steps, *_): waiter_max_attempts=None, execution_role_arn=None, ) + + def test_wait_for_completion_false_with_deferrable(self): + job_flow_id = "j-8989898989" + operator = EmrAddStepsOperator( + task_id="test_task", + job_flow_id=job_flow_id, + aws_conn_id="aws_default", + dag=DAG("test_dag_id", default_args=self.args), + wait_for_completion=True, + deferrable=True, + ) + + assert operator.wait_for_completion is False + + @patch("airflow.providers.amazon.aws.operators.emr.get_log_uri") + @patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.add_job_flow_steps") + def test_emr_add_steps_deferrable(self, mock_add_job_flow_steps, mock_get_log_uri): + mock_add_job_flow_steps.return_value = "test_step_id" + mock_get_log_uri.return_value = "test/log/uri" + job_flow_id = "j-8989898989" + operator = EmrAddStepsOperator( + task_id="test_task", + job_flow_id=job_flow_id, + aws_conn_id="aws_default", + dag=DAG("test_dag_id", 
default_args=self.args), + wait_for_completion=True, + deferrable=True, + ) + + with pytest.raises(TaskDeferred) as exc: + operator.execute(self.mock_context) + + assert isinstance(exc.value.trigger, EmrAddStepsTrigger), "Trigger is not a EmrAddStepsTrigger" diff --git a/tests/providers/amazon/aws/triggers/test_emr_trigger.py b/tests/providers/amazon/aws/triggers/test_emr_trigger.py new file mode 100644 index 0000000000000..0ec3b5af6eb8c --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_emr_trigger.py @@ -0,0 +1,170 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest import mock +from unittest.mock import AsyncMock + +import pytest +from botocore.exceptions import WaiterError + +from airflow.providers.amazon.aws.hooks.emr import EmrHook +from airflow.providers.amazon.aws.triggers.emr import EmrAddStepsTrigger +from airflow.triggers.base import TriggerEvent + +TEST_JOB_FLOW_ID = "test_job_flow_id" +TEST_STEP_IDS = ["step1", "step2"] +TEST_AWS_CONN_ID = "test-aws-id" +TEST_MAX_ATTEMPTS = 10 +TEST_POLL_INTERVAL = 10 + + +class TestEmrAddStepsTrigger: + def test_emr_add_steps_trigger_serialize(self): + emr_add_steps_trigger = EmrAddStepsTrigger( + job_flow_id=TEST_JOB_FLOW_ID, + step_ids=TEST_STEP_IDS, + aws_conn_id=TEST_AWS_CONN_ID, + max_attempts=TEST_MAX_ATTEMPTS, + poll_interval=TEST_POLL_INTERVAL, + ) + class_path, args = emr_add_steps_trigger.serialize() + assert class_path == "airflow.providers.amazon.aws.triggers.emr.EmrAddStepsTrigger" + assert args["job_flow_id"] == TEST_JOB_FLOW_ID + assert args["step_ids"] == TEST_STEP_IDS + assert args["poll_interval"] == str(TEST_POLL_INTERVAL) + assert args["max_attempts"] == str(TEST_MAX_ATTEMPTS) + assert args["aws_conn_id"] == TEST_AWS_CONN_ID + + @pytest.mark.asyncio + @mock.patch.object(EmrHook, "async_conn") + async def test_emr_add_steps_trigger_run(self, mock_async_conn): + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + a_mock.get_waiter().wait = AsyncMock() + + emr_add_steps_trigger = EmrAddStepsTrigger( + job_flow_id=TEST_JOB_FLOW_ID, + step_ids=TEST_STEP_IDS, + aws_conn_id=TEST_AWS_CONN_ID, + max_attempts=TEST_MAX_ATTEMPTS, + poll_interval=TEST_POLL_INTERVAL, + ) + + generator = emr_add_steps_trigger.run() + response = await generator.asend(None) + + assert response == TriggerEvent( + {"status": "success", "message": "Steps completed", "step_ids": TEST_STEP_IDS} + ) + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch.object(EmrHook, "async_conn") + async def 
test_emr_add_steps_trigger_run_multiple_attempts(self, mock_async_conn, mock_sleep): + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + error = WaiterError( + name="test_name", + reason="test_reason", + last_response={"Step": {"Status": {"State": "Running", "StateChangeReason": "test_reason"}}}, + ) + a_mock.get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True, error, error, True]) + mock_sleep.return_value = True + + emr_add_steps_trigger = EmrAddStepsTrigger( + job_flow_id=TEST_JOB_FLOW_ID, + step_ids=TEST_STEP_IDS, + aws_conn_id=TEST_AWS_CONN_ID, + max_attempts=TEST_MAX_ATTEMPTS, + poll_interval=TEST_POLL_INTERVAL, + ) + + generator = emr_add_steps_trigger.run() + response = await generator.asend(None) + + assert a_mock.get_waiter().wait.call_count == 6 + assert response == TriggerEvent( + {"status": "success", "message": "Steps completed", "step_ids": TEST_STEP_IDS} + ) + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch.object(EmrHook, "async_conn") + async def test_emr_add_steps_trigger_run_attempts_exceeded(self, mock_async_conn, mock_sleep): + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + error = WaiterError( + name="test_name", + reason="test_reason", + last_response={"Step": {"Status": {"State": "Running", "StateChangeReason": "test_reason"}}}, + ) + a_mock.get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True]) + mock_sleep.return_value = True + + emr_add_steps_trigger = EmrAddStepsTrigger( + job_flow_id=TEST_JOB_FLOW_ID, + step_ids=[TEST_STEP_IDS[0]], + aws_conn_id=TEST_AWS_CONN_ID, + max_attempts=2, + poll_interval=TEST_POLL_INTERVAL, + ) + + generator = emr_add_steps_trigger.run() + response = await generator.asend(None) + + assert a_mock.get_waiter().wait.call_count == 2 + assert response == TriggerEvent( + {"status": "failure", "message": "Steps failed: max attempts reached"} + ) + + @pytest.mark.asyncio + 
@mock.patch("asyncio.sleep") + @mock.patch.object(EmrHook, "async_conn") + async def test_emr_add_steps_trigger_run_attempts_failed(self, mock_async_conn, mock_sleep): + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + error_running = WaiterError( + name="test_name", + reason="test_reason", + last_response={"Step": {"Status": {"State": "Running", "StateChangeReason": "test_reason"}}}, + ) + error_failed = WaiterError( + name="test_name", + reason="Waiter encountered a terminal failure state:", + last_response={"Step": {"Status": {"State": "FAILED", "StateChangeReason": "test_reason"}}}, + ) + a_mock.get_waiter().wait.side_effect = AsyncMock( + side_effect=[error_running, error_running, error_failed] + ) + mock_sleep.return_value = True + + emr_add_steps_trigger = EmrAddStepsTrigger( + job_flow_id=TEST_JOB_FLOW_ID, + step_ids=[TEST_STEP_IDS[0]], + aws_conn_id=TEST_AWS_CONN_ID, + max_attempts=TEST_MAX_ATTEMPTS, + poll_interval=TEST_POLL_INTERVAL, + ) + + generator = emr_add_steps_trigger.run() + response = await generator.asend(None) + + assert a_mock.get_waiter().wait.call_count == 3 + assert response == TriggerEvent( + {"status": "failure", "message": f"Step {TEST_STEP_IDS[0]} failed: {error_failed}"} + )