From db345b57d0f21b12e3b2783639a41092106d4ac3 Mon Sep 17 00:00:00 2001 From: Syed Hussaain <103602455+syedahsn@users.noreply.github.com> Date: Thu, 15 Dec 2022 14:59:07 -0800 Subject: [PATCH] Add EMR Notebook operators (#28312) --- airflow/providers/amazon/aws/operators/emr.py | 165 ++++++++++ airflow/providers/amazon/aws/sensors/emr.py | 74 ++++- .../operators/emr.rst | 42 +++ .../operators/test_emr_notebook_execution.py | 290 ++++++++++++++++++ .../sensors/test_emr_notebook_execution.py | 76 +++++ .../aws/example_emr_notebook_execution.py | 123 ++++++++ 6 files changed, 763 insertions(+), 7 deletions(-) create mode 100644 tests/providers/amazon/aws/operators/test_emr_notebook_execution.py create mode 100644 tests/providers/amazon/aws/sensors/test_emr_notebook_execution.py create mode 100644 tests/system/providers/amazon/aws/example_emr_notebook_execution.py diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 451c85e4d1453..f8a6929ea0e3b 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -117,6 +117,171 @@ def execute(self, context: Context) -> list[str]: return emr_hook.add_job_flow_steps(job_flow_id=job_flow_id, steps=steps, wait_for_completion=True) +class EmrStartNotebookExecutionOperator(BaseOperator): + """ + An operator that starts an EMR notebook execution. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:EmrStartNotebookExecutionOperator` + + :param editor_id: The unique identifier of the EMR notebook to use for notebook execution. + :param relative_path: The path and file name of the notebook file for this execution, + relative to the path specified for the EMR notebook. + :param cluster_id: The unique identifier of the EMR cluster the notebook is attached to. 
+    :param service_role: The name or ARN of the IAM role that is used as the service role
+        for Amazon EMR (the EMR role) for the notebook execution.
+    :param notebook_execution_name: Optional name for the notebook execution.
+    :param notebook_params: Input parameters in JSON format passed to the EMR notebook at
+        runtime for execution.
+    :param notebook_instance_security_group_id: The unique identifier of the Amazon EC2
+        security group to associate with the EMR notebook for this notebook execution.
+    :param master_instance_security_group_id: Optional unique ID of an EC2 security
+        group to associate with the master instance of the EMR cluster for this notebook execution.
+    :param tags: Optional list of key value pair to associate with the notebook execution.
+    :param waiter_countdown: Total amount of time the operator will wait for the notebook to stop.
+        Defaults to 25 * 60 seconds.
+    :param waiter_check_interval_seconds: Number of seconds between polling the state of the notebook.
+        Defaults to 60 seconds.
+ """ + + template_fields: Sequence[str] = ( + "editor_id", + "cluster_id", + "relative_path", + "service_role", + "notebook_execution_name", + "notebook_params", + "notebook_instance_security_group_id", + "master_instance_security_group_id", + "tags", + ) + + def __init__( + self, + editor_id: str, + relative_path: str, + cluster_id: str, + service_role: str, + notebook_execution_name: str | None = None, + notebook_params: str | None = None, + notebook_instance_security_group_id: str | None = None, + master_instance_security_group_id: str | None = None, + tags: list | None = None, + wait_for_completion: bool = False, + aws_conn_id: str = "aws_default", + waiter_countdown: int = 25 * 60, + waiter_check_interval_seconds: int = 60, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.editor_id = editor_id + self.relative_path = relative_path + self.service_role = service_role + self.notebook_execution_name = notebook_execution_name or f"emr_notebook_{uuid4()}" + self.notebook_params = notebook_params or "" + self.notebook_instance_security_group_id = notebook_instance_security_group_id or "" + self.tags = tags or [] + self.wait_for_completion = wait_for_completion + self.cluster_id = cluster_id + self.aws_conn_id = aws_conn_id + self.waiter_countdown = waiter_countdown + self.waiter_check_interval_seconds = waiter_check_interval_seconds + self.master_instance_security_group_id = master_instance_security_group_id + + def execute(self, context: Context): + execution_engine = { + "Id": self.cluster_id, + "Type": "EMR", + "MasterInstanceSecurityGroupId": self.master_instance_security_group_id or "", + } + emr_hook = EmrHook(aws_conn_id=self.aws_conn_id) + + response = emr_hook.conn.start_notebook_execution( + EditorId=self.editor_id, + RelativePath=self.relative_path, + NotebookExecutionName=self.notebook_execution_name, + NotebookParams=self.notebook_params, + ExecutionEngine=execution_engine, + ServiceRole=self.service_role, + 
NotebookInstanceSecurityGroupId=self.notebook_instance_security_group_id, + Tags=self.tags, + ) + + if response["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise AirflowException(f"Starting notebook execution failed: {response}") + + self.log.info("Notebook execution started: %s", response["NotebookExecutionId"]) + notebook_execution_id = response["NotebookExecutionId"] + if self.wait_for_completion: + waiter( + get_state_callable=emr_hook.conn.describe_notebook_execution, + get_state_args={"NotebookExecutionId": notebook_execution_id}, + parse_response=["NotebookExecution", "Status"], + desired_state={"RUNNING", "FINISHED"}, + failure_states={"FAILED"}, + object_type="notebook execution", + action="starting", + countdown=self.waiter_countdown, + check_interval_seconds=self.waiter_check_interval_seconds, + ) + return notebook_execution_id + + +class EmrStopNotebookExecutionOperator(BaseOperator): + """ + An operator that stops a running EMR notebook execution. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:EmrStopNotebookExecutionOperator` + + :param notebook_execution_id: The unique identifier of the notebook execution. + :param wait_for_completion: If True, the operator will wait for the notebook. + to be in a STOPPED or FINISHED state. Defaults to False. + :param aws_conn_id: aws connection to use. + :param waiter_countdown: Total amount of time the operator will wait for the notebook to stop. + Defaults to 25 * 60 seconds. + :param waiter_check_interval_seconds: Number of seconds between polling the state of the notebook. + Defaults to 60 seconds. 
+ """ + + template_fields: Sequence[str] = ("notebook_execution_id",) + + def __init__( + self, + notebook_execution_id: str, + wait_for_completion: bool = False, + aws_conn_id: str = "aws_default", + waiter_countdown: int = 25 * 60, + waiter_check_interval_seconds: int = 60, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.notebook_execution_id = notebook_execution_id + self.wait_for_completion = wait_for_completion + self.aws_conn_id = aws_conn_id + self.waiter_countdown = waiter_countdown + self.waiter_check_interval_seconds = waiter_check_interval_seconds + + def execute(self, context: Context) -> None: + emr_hook = EmrHook(aws_conn_id=self.aws_conn_id) + emr_hook.conn.stop_notebook_execution(NotebookExecutionId=self.notebook_execution_id) + + if self.wait_for_completion: + waiter( + get_state_callable=emr_hook.conn.describe_notebook_execution, + get_state_args={"NotebookExecutionId": self.notebook_execution_id}, + parse_response=["NotebookExecution", "Status"], + desired_state={"STOPPED", "FINISHED"}, + failure_states={"FAILED"}, + object_type="notebook execution", + action="stopped", + countdown=self.waiter_countdown, + check_interval_seconds=self.waiter_check_interval_seconds, + ) + + class EmrEksCreateClusterOperator(BaseOperator): """ An operator that creates EMR on EKS virtual clusters. 
diff --git a/airflow/providers/amazon/aws/sensors/emr.py b/airflow/providers/amazon/aws/sensors/emr.py
index a3684fa249a1d..04e28977169bc 100644
--- a/airflow/providers/amazon/aws/sensors/emr.py
+++ b/airflow/providers/amazon/aws/sensors/emr.py
@@ -74,11 +74,7 @@ def poke(self, context: Context):
             return True
 
         if state in self.failed_states:
-            final_message = "EMR job failed"
-            failure_message = self.failure_message_from_response(response)
-            if failure_message:
-                final_message += " " + failure_message
-            raise AirflowException(final_message)
+            raise AirflowException(f"EMR job failed: {self.failure_message_from_response(response)}")
 
         return False
 
@@ -93,7 +89,7 @@ def get_emr_response(self) -> dict[str, Any]:
     @staticmethod
     def state_from_response(response: dict[str, Any]) -> str:
         """
-        Get state from response dictionary.
+        Get state from boto3 response.
 
         :param response: response from AWS API
         :return: state
@@ -103,7 +99,7 @@ def state_from_response(response: dict[str, Any]) -> str:
     @staticmethod
     def failure_message_from_response(response: dict[str, Any]) -> str | None:
         """
-        Get failure message from response dictionary.
+        Get failure message from boto3 response.
 
         :param response: response from AWS API
         :return: failure message
@@ -299,6 +295,70 @@ def hook(self) -> EmrContainerHook:
         return EmrContainerHook(self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)
 
 
+class EmrNotebookExecutionSensor(EmrBaseSensor):
+    """
+    Polls the state of the EMR notebook execution until it reaches
+    any of the target states.
+    If a failure state is reached, the sensor throws an error, and fails the task.
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:EmrNotebookExecutionSensor`
+
+    :param notebook_execution_id: Unique id of the notebook execution to be poked.
+    :param target_states: the states the sensor will wait for the execution to reach.
+        Default target_states is ``FINISHED``.
+    :param failed_states: if the execution reaches any of the failed_states, the sensor will fail.
+        Default failed_states is ``FAILED``.
+    """
+
+    template_fields: Sequence[str] = ("notebook_execution_id",)
+
+    FAILURE_STATES = {"FAILED"}
+    COMPLETED_STATES = {"FINISHED"}
+
+    def __init__(
+        self,
+        notebook_execution_id: str,
+        target_states: Iterable[str] | None = None,
+        failed_states: Iterable[str] | None = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.notebook_execution_id = notebook_execution_id
+        self.target_states = target_states or self.COMPLETED_STATES
+        self.failed_states = failed_states or self.FAILURE_STATES
+
+    def get_emr_response(self) -> dict[str, Any]:
+        emr_client = self.get_hook().get_conn()
+        self.log.info("Poking notebook %s", self.notebook_execution_id)
+
+        return emr_client.describe_notebook_execution(NotebookExecutionId=self.notebook_execution_id)
+
+    @staticmethod
+    def state_from_response(response: dict[str, Any]) -> str:
+        """
+        Get the notebook execution state from the boto3 response.
+
+        .. seealso::
+            https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr.html#EMR.Client.describe_notebook_execution
+
+        :return: notebook execution state
+        """
+        return response["NotebookExecution"]["Status"]
+
+    @staticmethod
+    def failure_message_from_response(response: dict[str, Any]) -> str | None:
+        """
+        Get failure message from response dictionary.
+ + :param response: response from AWS API + :return: failure message + """ + cluster_status = response["NotebookExecution"] + return cluster_status.get("LastStateChangeReason", None) + + class EmrJobFlowSensor(EmrBaseSensor): """ Asks for the state of the EMR JobFlow (Cluster) until it reaches diff --git a/docs/apache-airflow-providers-amazon/operators/emr.rst b/docs/apache-airflow-providers-amazon/operators/emr.rst index a9c434777d71c..6597413700a8e 100644 --- a/docs/apache-airflow-providers-amazon/operators/emr.rst +++ b/docs/apache-airflow-providers-amazon/operators/emr.rst @@ -124,9 +124,51 @@ To modify an existing EMR container you can use :start-after: [START howto_operator_emr_modify_cluster] :end-before: [END howto_operator_emr_modify_cluster] +.. _howto/operator:EmrStartNotebookExecutionOperator: + +Start an EMR notebook execution +==================================== + +You can use :class:`~airflow.providers.amazon.aws.operators.emr.EmrStartNotebookExecutionOperator` to +start a notebook execution on an existing notebook attached to a running cluster. + +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_emr_notebook_execution.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_emr_start_notebook_execution] + :end-before: [END howto_operator_emr_start_notebook_execution] + +.. _howto/operator:EmrStopNotebookExecutionOperator: + +Stop an EMR notebook execution +==================================== + +You can use :class:`~airflow.providers.amazon.aws.operators.emr.EmrStopNotebookExecutionOperator` to +stop a running notebook execution. + +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_emr_notebook_execution.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_emr_stop_notebook_execution] + :end-before: [END howto_operator_emr_stop_notebook_execution] + Sensors ------- +.. 
_howto/sensor:EmrNotebookExecutionSensor: + +Wait on an EMR notebook execution state +======================================= + +To monitor the state of an EMR notebook execution you can use +:class:`~airflow.providers.amazon.aws.sensors.emr.EmrNotebookExecutionSensor`. + +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_emr_notebook_execution.py + :language: python + :dedent: 4 + :start-after: [START howto_sensor_emr_notebook_execution] + :end-before: [END howto_sensor_emr_notebook_execution] + .. _howto/sensor:EmrJobFlowSensor: Wait on an Amazon EMR job flow state diff --git a/tests/providers/amazon/aws/operators/test_emr_notebook_execution.py b/tests/providers/amazon/aws/operators/test_emr_notebook_execution.py new file mode 100644 index 0000000000000..dc8be6a2881d7 --- /dev/null +++ b/tests/providers/amazon/aws/operators/test_emr_notebook_execution.py @@ -0,0 +1,290 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.operators.emr import ( + EmrStartNotebookExecutionOperator, + EmrStopNotebookExecutionOperator, +) + +PARAMS = { + "EditorId": "test_editor", + "RelativePath": "test_relative_path", + "ServiceRole": "test_role", + "NotebookExecutionName": "test_name", + "NotebookParams": "test_params", + "NotebookInstanceSecurityGroupId": "test_notebook_instance_security_group_id", + "Tags": [{"test_key": "test_value"}], + "ExecutionEngine": { + "Id": "test_cluster_id", + "Type": "EMR", + "MasterInstanceSecurityGroupId": "test_master_instance_security_group_id", + }, +} + + +class TestEmrStartNotebookExecutionOperator: + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.conn") + def test_start_notebook_execution_wait_for_completion(self, mock_conn): + test_execution_id = "test-execution-id" + mock_conn.start_notebook_execution.return_value = { + "NotebookExecutionId": test_execution_id, + "ResponseMetadata": { + "HTTPStatusCode": 200, + }, + } + mock_conn.describe_notebook_execution.return_value = {"NotebookExecution": {"Status": "FINISHED"}} + + op = EmrStartNotebookExecutionOperator( + task_id="test-id", + editor_id=PARAMS["EditorId"], + relative_path=PARAMS["RelativePath"], + cluster_id=PARAMS["ExecutionEngine"]["Id"], + service_role=PARAMS["ServiceRole"], + notebook_execution_name=PARAMS["NotebookExecutionName"], + notebook_params=PARAMS["NotebookParams"], + notebook_instance_security_group_id=PARAMS["NotebookInstanceSecurityGroupId"], + master_instance_security_group_id=PARAMS["ExecutionEngine"]["MasterInstanceSecurityGroupId"], + tags=PARAMS["Tags"], + wait_for_completion=True, + ) + op_response = op.execute(None) + + mock_conn.start_notebook_execution.assert_called_once_with(**PARAMS) + mock_conn.describe_notebook_execution.assert_called_once_with(NotebookExecutionId=test_execution_id) + assert 
op_response == test_execution_id + + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.conn") + def test_start_notebook_execution_no_wait_for_completion(self, mock_conn): + test_execution_id = "test-execution-id" + mock_conn.start_notebook_execution.return_value = { + "NotebookExecutionId": test_execution_id, + "ResponseMetadata": { + "HTTPStatusCode": 200, + }, + } + + op = EmrStartNotebookExecutionOperator( + task_id="test-id", + editor_id=PARAMS["EditorId"], + relative_path=PARAMS["RelativePath"], + cluster_id=PARAMS["ExecutionEngine"]["Id"], + service_role=PARAMS["ServiceRole"], + notebook_execution_name=PARAMS["NotebookExecutionName"], + notebook_params=PARAMS["NotebookParams"], + notebook_instance_security_group_id=PARAMS["NotebookInstanceSecurityGroupId"], + master_instance_security_group_id=PARAMS["ExecutionEngine"]["MasterInstanceSecurityGroupId"], + tags=PARAMS["Tags"], + ) + op_response = op.execute(None) + + mock_conn.start_notebook_execution.assert_called_once_with(**PARAMS) + assert op.wait_for_completion is False + assert not mock_conn.describe_notebook_execution.called + assert op_response == test_execution_id + + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.conn") + def test_start_notebook_execution_http_code_fail(self, mock_conn): + test_execution_id = "test-execution-id" + mock_conn.start_notebook_execution.return_value = { + "NotebookExecutionId": test_execution_id, + "ResponseMetadata": { + "HTTPStatusCode": 400, + }, + } + op = EmrStartNotebookExecutionOperator( + task_id="test-id", + editor_id=PARAMS["EditorId"], + relative_path=PARAMS["RelativePath"], + cluster_id=PARAMS["ExecutionEngine"]["Id"], + service_role=PARAMS["ServiceRole"], + notebook_execution_name=PARAMS["NotebookExecutionName"], + notebook_params=PARAMS["NotebookParams"], + notebook_instance_security_group_id=PARAMS["NotebookInstanceSecurityGroupId"], + master_instance_security_group_id=PARAMS["ExecutionEngine"]["MasterInstanceSecurityGroupId"], + 
tags=PARAMS["Tags"], + ) + with pytest.raises(AirflowException, match=r"Starting notebook execution failed:"): + op.execute(None) + + mock_conn.start_notebook_execution.assert_called_once_with(**PARAMS) + + @mock.patch("time.sleep", return_value=None) + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.conn") + def test_start_notebook_execution_wait_for_completion_multiple_attempts(self, mock_conn, _): + test_execution_id = "test-execution-id" + mock_conn.start_notebook_execution.return_value = { + "NotebookExecutionId": test_execution_id, + "ResponseMetadata": { + "HTTPStatusCode": 200, + }, + } + mock_conn.describe_notebook_execution.side_effect = [ + {"NotebookExecution": {"Status": "PENDING"}}, + {"NotebookExecution": {"Status": "PENDING"}}, + {"NotebookExecution": {"Status": "FINISHED"}}, + ] + + op = EmrStartNotebookExecutionOperator( + task_id="test-id", + editor_id=PARAMS["EditorId"], + relative_path=PARAMS["RelativePath"], + cluster_id=PARAMS["ExecutionEngine"]["Id"], + service_role=PARAMS["ServiceRole"], + notebook_execution_name=PARAMS["NotebookExecutionName"], + notebook_params=PARAMS["NotebookParams"], + notebook_instance_security_group_id=PARAMS["NotebookInstanceSecurityGroupId"], + master_instance_security_group_id=PARAMS["ExecutionEngine"]["MasterInstanceSecurityGroupId"], + tags=PARAMS["Tags"], + wait_for_completion=True, + ) + op_response = op.execute(None) + + mock_conn.start_notebook_execution.assert_called_once_with(**PARAMS) + mock_conn.describe_notebook_execution.assert_called_with(NotebookExecutionId=test_execution_id) + assert mock_conn.describe_notebook_execution.call_count == 3 + assert op_response == test_execution_id + + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.conn") + def test_start_notebook_execution_wait_for_completion_fail_state(self, mock_conn): + test_execution_id = "test-execution-id" + mock_conn.start_notebook_execution.return_value = { + "NotebookExecutionId": test_execution_id, + 
"ResponseMetadata": { + "HTTPStatusCode": 200, + }, + } + mock_conn.describe_notebook_execution.return_value = {"NotebookExecution": {"Status": "FAILED"}} + + op = EmrStartNotebookExecutionOperator( + task_id="test-id", + editor_id=PARAMS["EditorId"], + relative_path=PARAMS["RelativePath"], + cluster_id=PARAMS["ExecutionEngine"]["Id"], + service_role=PARAMS["ServiceRole"], + notebook_execution_name=PARAMS["NotebookExecutionName"], + notebook_params=PARAMS["NotebookParams"], + notebook_instance_security_group_id=PARAMS["NotebookInstanceSecurityGroupId"], + master_instance_security_group_id=PARAMS["ExecutionEngine"]["MasterInstanceSecurityGroupId"], + tags=PARAMS["Tags"], + wait_for_completion=True, + ) + with pytest.raises(AirflowException, match=r"Notebook Execution reached failure state FAILED\."): + op.execute(None) + mock_conn.start_notebook_execution.assert_called_once_with(**PARAMS) + mock_conn.describe_notebook_execution.assert_called_once_with(NotebookExecutionId=test_execution_id) + + +class TestStopEmrNotebookExecutionOperator: + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.conn") + def test_stop_notebook_execution(self, mock_conn): + mock_conn.stop_notebook_execution.return_value = None + test_execution_id = "test-execution-id" + + op = EmrStopNotebookExecutionOperator( + task_id="test-id", + notebook_execution_id=test_execution_id, + ) + + op.execute(None) + + mock_conn.stop_notebook_execution.assert_called_once_with(NotebookExecutionId=test_execution_id) + assert not mock_conn.describe_notebook_execution.called + + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.conn") + def test_stop_notebook_execution_wait_for_completion(self, mock_conn): + mock_conn.stop_notebook_execution.return_value = None + mock_conn.describe_notebook_execution.return_value = {"NotebookExecution": {"Status": "FINISHED"}} + test_execution_id = "test-execution-id" + + op = EmrStopNotebookExecutionOperator( + task_id="test-id", 
notebook_execution_id=test_execution_id, wait_for_completion=True + ) + + op.execute(None) + mock_conn.stop_notebook_execution.assert_called_once_with(NotebookExecutionId=test_execution_id) + mock_conn.describe_notebook_execution.assert_called_once_with(NotebookExecutionId=test_execution_id) + + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.conn") + def test_stop_notebook_execution_wait_for_completion_fail_state(self, mock_conn): + mock_conn.stop_notebook_execution.return_value = None + mock_conn.describe_notebook_execution.return_value = {"NotebookExecution": {"Status": "FAILED"}} + test_execution_id = "test-execution-id" + + op = EmrStopNotebookExecutionOperator( + task_id="test-id", notebook_execution_id=test_execution_id, wait_for_completion=True + ) + + with pytest.raises(AirflowException, match=r"Notebook Execution reached failure state FAILED."): + op.execute(None) + mock_conn.stop_notebook_execution.assert_called_once_with(NotebookExecutionId=test_execution_id) + mock_conn.describe_notebook_execution.assert_called_once_with(NotebookExecutionId=test_execution_id) + + @mock.patch("time.sleep", return_value=None) + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.conn") + def test_stop_notebook_execution_wait_for_completion_multiple_attempts(self, mock_conn, _): + mock_conn.stop_notebook_execution.return_value = None + mock_conn.describe_notebook_execution.side_effect = [ + {"NotebookExecution": {"Status": "PENDING"}}, + {"NotebookExecution": {"Status": "PENDING"}}, + {"NotebookExecution": {"Status": "FINISHED"}}, + ] + test_execution_id = "test-execution-id" + + op = EmrStopNotebookExecutionOperator( + task_id="test-id", notebook_execution_id=test_execution_id, wait_for_completion=True + ) + + op.execute(None) + mock_conn.stop_notebook_execution.assert_called_once_with(NotebookExecutionId=test_execution_id) + mock_conn.describe_notebook_execution.assert_called_with(NotebookExecutionId=test_execution_id) + assert 
mock_conn.describe_notebook_execution.call_count == 3 + + @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.conn") + def test_stop_notebook_execution_waiter_config(self, mock_conn, mock_waiter): + mock_waiter.return_value = True + + test_execution_id = "test-execution-id" + + op = EmrStopNotebookExecutionOperator( + task_id="test-id", + notebook_execution_id=test_execution_id, + wait_for_completion=True, + waiter_countdown=400, + waiter_check_interval_seconds=12, + ) + + op.execute(None) + mock_conn.stop_notebook_execution.assert_called_once_with(NotebookExecutionId=test_execution_id) + mock_waiter.assert_called_once_with( + get_state_callable=mock_conn.describe_notebook_execution, + get_state_args={"NotebookExecutionId": test_execution_id}, + parse_response=["NotebookExecution", "Status"], + desired_state={"STOPPED", "FINISHED"}, + failure_states={"FAILED"}, + object_type="notebook execution", + action="stopped", + countdown=400, + check_interval_seconds=12, + ) diff --git a/tests/providers/amazon/aws/sensors/test_emr_notebook_execution.py b/tests/providers/amazon/aws/sensors/test_emr_notebook_execution.py new file mode 100644 index 0000000000000..9f051cfcef340 --- /dev/null +++ b/tests/providers/amazon/aws/sensors/test_emr_notebook_execution.py @@ -0,0 +1,76 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Any +from unittest import mock + +import pytest + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.sensors.emr import EmrNotebookExecutionSensor + + +class TestEmrNotebookExecutionSensor: + def _generate_response(self, status: str, reason: str | None = None) -> dict[str, Any]: + return { + "NotebookExecution": { + "Status": status, + "LastStateChangeReason": reason, + }, + "ResponseMetadata": { + "HTTPStatusCode": 200, + }, + } + + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.conn") + def test_emr_notebook_execution_sensor_success_state(self, mock_conn): + mock_conn.describe_notebook_execution.return_value = self._generate_response("FINISHED") + sensor = EmrNotebookExecutionSensor( + task_id="test_task", + poke_interval=0, + notebook_execution_id="test-execution-id", + ) + sensor.poke(None) + mock_conn.describe_notebook_execution.assert_called_once_with(NotebookExecutionId="test-execution-id") + + @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.conn") + def test_emr_notebook_execution_sensor_failed_state(self, mock_conn): + error_reason = "Test error" + mock_conn.describe_notebook_execution.return_value = self._generate_response("FAILED", error_reason) + sensor = EmrNotebookExecutionSensor( + task_id="test_task", + poke_interval=0, + notebook_execution_id="test-execution-id", + ) + with pytest.raises(AirflowException, match=rf"EMR job failed: {error_reason}"): + sensor.poke(None) + 
mock_conn.describe_notebook_execution.assert_called_once_with(NotebookExecutionId="test-execution-id")
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.conn")
+    def test_emr_notebook_execution_sensor_success_state_multiple(self, mock_conn):
+        return_values = [self._generate_response("PENDING") for _ in range(2)]
+        return_values.append(self._generate_response("FINISHED"))
+        mock_conn.describe_notebook_execution.side_effect = return_values
+        sensor = EmrNotebookExecutionSensor(
+            task_id="test_task",
+            poke_interval=0,
+            notebook_execution_id="test-execution-id",
+        )
+        assert [sensor.poke(None) for _ in range(3)] == [False, False, True]
+        assert mock_conn.describe_notebook_execution.call_count == 3
diff --git a/tests/system/providers/amazon/aws/example_emr_notebook_execution.py b/tests/system/providers/amazon/aws/example_emr_notebook_execution.py
new file mode 100644
index 0000000000000..e24d465832ee1
--- /dev/null
+++ b/tests/system/providers/amazon/aws/example_emr_notebook_execution.py
@@ -0,0 +1,123 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+ +from __future__ import annotations + +from datetime import datetime + +from airflow import DAG +from airflow.models.baseoperator import chain +from airflow.providers.amazon.aws.operators.emr import ( + EmrStartNotebookExecutionOperator, + EmrStopNotebookExecutionOperator, +) +from airflow.providers.amazon.aws.sensors.emr import EmrNotebookExecutionSensor +from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder + +DAG_ID = "example_emr_notebook" +# Externally fetched variables: +EDITOR_ID_KEY = "EDITOR_ID" +CLUSTER_ID_KEY = "CLUSTER_ID" + +sys_test_context_task = ( + SystemTestContextBuilder().add_variable(EDITOR_ID_KEY).add_variable(CLUSTER_ID_KEY).build() +) + +with DAG( + dag_id=DAG_ID, + start_date=datetime(2021, 1, 1), + schedule="@once", + catchup=False, + tags=["example"], +) as dag: + test_context = sys_test_context_task() + env_id = test_context[ENV_ID_KEY] + editor_id = test_context[EDITOR_ID_KEY] + cluster_id = test_context[CLUSTER_ID_KEY] + + # [START howto_operator_emr_start_notebook_execution] + start_execution = EmrStartNotebookExecutionOperator( + task_id="start_execution", + editor_id=editor_id, + cluster_id=cluster_id, + relative_path="EMR-System-Test.ipynb", + service_role="EMR_Notebooks_DefaultRole", + ) + # [END howto_operator_emr_start_notebook_execution] + + notebook_execution_id_1 = start_execution.output + + # [START howto_sensor_emr_notebook_execution] + wait_for_execution_start = EmrNotebookExecutionSensor( + task_id="wait_for_execution_start", + notebook_execution_id=notebook_execution_id_1, + target_states={"RUNNING"}, + poke_interval=5, + ) + # [END howto_sensor_emr_notebook_execution] + + # [START howto_operator_emr_stop_notebook_execution] + stop_execution = EmrStopNotebookExecutionOperator( + task_id="stop_execution", + notebook_execution_id=notebook_execution_id_1, + ) + # [END howto_operator_emr_stop_notebook_execution] + + wait_for_execution_stop = EmrNotebookExecutionSensor( + 
task_id="wait_for_execution_stop", + notebook_execution_id=notebook_execution_id_1, + target_states={"STOPPED"}, + poke_interval=5, + ) + finish_execution = EmrStartNotebookExecutionOperator( + task_id="finish_execution", + editor_id=editor_id, + cluster_id=cluster_id, + relative_path="EMR-System-Test.ipynb", + service_role="EMR_Notebooks_DefaultRole", + ) + notebook_execution_id_2 = finish_execution.output + wait_for_execution_finish = EmrNotebookExecutionSensor( + task_id="wait_for_execution_finish", + notebook_execution_id=notebook_execution_id_2, + poke_interval=5, + ) + + chain( + # TEST SETUP + test_context, + # TEST BODY + start_execution, + wait_for_execution_start, + stop_execution, + wait_for_execution_stop, + finish_execution, + # TEST TEARDOWN + wait_for_execution_finish, + ) + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag)