From cf77c3b96609aa8c260566274d54b06eb38c8100 Mon Sep 17 00:00:00 2001 From: Pankaj Singh <98807258+pankajastro@users.noreply.github.com> Date: Tue, 14 Mar 2023 01:31:36 +0530 Subject: [PATCH] Add deferrable mode in RedshiftPauseClusterOperator (#28850)
In this PR, I'm adding `aiobotocore` as an additional dependency until https://github.com/aio-libs/aiobotocore/issues/976 has been resolved. I'm adding `AwsBaseAsyncHook`, a basic async AWS hook. At present it supports the default `botocore` auth, i.e. if no Airflow connection is provided then auth uses ENV, and if an Airflow connection is provided then basic auth with secret-key/access-key-id/profile/token and arn-method. Maybe we can support the other auth methods incrementally, depending on community interest. Because the dependency makes things a little complicated, I have created a new test module `deferrable` inside the AWS provider tests and I'm keeping the async-related tests in this module. I have also added a new CI job to run the AWS deferrable operator tests, and I'm ignoring the deferrable tests in the other CI job test runs. Add a trigger class which will wait until the Redshift pause request reaches a terminal state, i.e. paused or failed. We also have retry logic like the sync operator. This PR donates the `RedshiftPauseClusterOperatorAsync` developed in the [astronomer-providers](https://github.com/astronomer/astronomer-providers) repo to Apache Airflow.
--- .github/workflows/ci.yml | 24 ++++ .../providers/amazon/aws/hooks/base_aws.py | 131 +++++++++++++++++- .../amazon/aws/hooks/redshift_cluster.py | 83 ++++++++++- .../amazon/aws/operators/redshift_cluster.py | 60 ++++++-- .../providers/amazon/aws/triggers/__init__.py | 17 +++ .../amazon/aws/triggers/redshift_cluster.py | 77 ++++++++++ airflow/providers/amazon/provider.yaml | 4 + .../deferrable.rst | 33 +++++ .../apache-airflow-providers-amazon/index.rst | 1 + .../operators/redshift_cluster.rst | 5 +- docs/spelling_wordlist.txt | 2 + generated/provider_dependencies.json | 1 + .../amazon/aws/deferrable/__init__.py | 16 +++ .../amazon/aws/deferrable/hooks/__init__.py | 16 +++ .../aws/deferrable/hooks/test_base_aws.py | 102 ++++++++++++++ .../deferrable/hooks/test_redshift_cluster.py | 84 +++++++++++ .../aws/deferrable/triggers/__init__.py | 17 +++ .../triggers/test_redshift_cluster.py | 93 +++++++++++++ .../aws/operators/test_redshift_cluster.py | 36 ++++- tests/providers/amazon/aws/utils/compat.py | 37 +++++ 20 files changed, 821 insertions(+), 18 deletions(-) create mode 100644 airflow/providers/amazon/aws/triggers/__init__.py create mode 100644 airflow/providers/amazon/aws/triggers/redshift_cluster.py create mode 100644 docs/apache-airflow-providers-amazon/deferrable.rst create mode 100644 tests/providers/amazon/aws/deferrable/__init__.py create mode 100644 tests/providers/amazon/aws/deferrable/hooks/__init__.py create mode 100644 tests/providers/amazon/aws/deferrable/hooks/test_base_aws.py create mode 100644 tests/providers/amazon/aws/deferrable/hooks/test_redshift_cluster.py create mode 100644 tests/providers/amazon/aws/deferrable/triggers/__init__.py create mode 100644 tests/providers/amazon/aws/deferrable/triggers/test_redshift_cluster.py create mode 100644 tests/providers/amazon/aws/utils/compat.py
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 64a95a87058bd..224de35488010 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -799,6 +799,30 @@ jobs: run: breeze ci fix-ownership if: always() +
tests-aws-async-provider: + timeout-minutes: 50 + name: "Pytest for AWS Async Provider" + runs-on: "${{needs.build-info.outputs.runs-on}}" + needs: [build-info, wait-for-ci-images] + if: needs.build-info.outputs.run-tests == 'true' + steps: + - name: Cleanup repo + shell: bash + run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*" + - name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )" + uses: actions/checkout@v3 + with: + persist-credentials: false + - name: "Prepare breeze & CI image" + uses: ./.github/actions/prepare_breeze_and_image + - name: "Run AWS Async Test" + run: "breeze shell \ + 'pip install aiobotocore>=2.1.1 && pytest /opt/airflow/tests/providers/amazon/aws/deferrable'" + - name: "Post Tests" + uses: ./.github/actions/post_tests + - name: "Fix ownership" + run: breeze ci fix-ownership + if: always() tests-helm: timeout-minutes: 80 diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py index ca60902fc075c..d797653476ef6 100644 --- a/airflow/providers/amazon/aws/hooks/base_aws.py +++ b/airflow/providers/amazon/aws/hooks/base_aws.py @@ -51,7 +51,10 @@ from airflow.compat.functools import cached_property from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowNotFoundException +from airflow.exceptions import ( + AirflowException, + AirflowNotFoundException, +) from airflow.hooks.base import BaseHook from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper from airflow.providers.amazon.aws.waiters.base_waiter import BaseBotoWaiter @@ -62,7 +65,6 @@ BaseAwsConnection = TypeVar("BaseAwsConnection", bound=Union[boto3.client, boto3.resource]) - if TYPE_CHECKING: from airflow.models.connection import Connection # Avoid circular imports. @@ -877,3 +879,128 @@ def _parse_s3_config(config_file_name: str, config_format: str | None = "boto", config_format=config_format, profile=profile, ) + + +try: + import aiobotocore.credentials + from aiobotocore.session import AioSession, get_session +except ImportError: + pass + + +class BaseAsyncSessionFactory(BaseSessionFactory): + """ + Base AWS Session Factory class to handle aiobotocore session creation. 
+ + It currently, handles ENV, AWS secret key and STS client method ``assume_role`` + provided in Airflow connection + """ + + async def get_role_credentials(self) -> dict: + """Get the role_arn, method credentials from connection details and get the role credentials detail""" + async with self._basic_session.create_client("sts", region_name=self.region_name) as client: + response = await client.assume_role( + RoleArn=self.role_arn, + RoleSessionName=self._strip_invalid_session_name_characters(f"Airflow_{self.conn.conn_id}"), + **self.conn.assume_role_kwargs, + ) + return response["Credentials"] + + async def _get_refresh_credentials(self) -> dict[str, Any]: + self.log.debug("Refreshing credentials") + assume_role_method = self.conn.assume_role_method + if assume_role_method != "assume_role": + raise NotImplementedError(f"assume_role_method={assume_role_method} not expected") + + credentials = await self.get_role_credentials() + + expiry_time = credentials["Expiration"].isoformat() + self.log.debug("New credentials expiry_time: %s", expiry_time) + credentials = { + "access_key": credentials.get("AccessKeyId"), + "secret_key": credentials.get("SecretAccessKey"), + "token": credentials.get("SessionToken"), + "expiry_time": expiry_time, + } + return credentials + + def _get_session_with_assume_role(self) -> AioSession: + + assume_role_method = self.conn.assume_role_method + if assume_role_method != "assume_role": + raise NotImplementedError(f"assume_role_method={assume_role_method} not expected") + + credentials = aiobotocore.credentials.AioRefreshableCredentials.create_from_metadata( + metadata=self._get_refresh_credentials(), + refresh_using=self._get_refresh_credentials, + method="sts-assume-role", + ) + + session = aiobotocore.session.get_session() + session._credentials = credentials + return session + + @cached_property + def _basic_session(self) -> AioSession: + """Cached property with basic aiobotocore.session.AioSession.""" + session_kwargs = self.conn.session_kwargs + aws_access_key_id = session_kwargs.get("aws_access_key_id") + aws_secret_access_key = session_kwargs.get("aws_secret_access_key") + aws_session_token = session_kwargs.get("aws_session_token") + region_name = session_kwargs.get("region_name") + profile_name = session_kwargs.get("profile_name") + + aio_session = get_session() + if profile_name is not None: + aio_session.set_config_variable("profile", profile_name) + if aws_access_key_id or aws_secret_access_key or aws_session_token: + aio_session.set_credentials( + access_key=aws_access_key_id, + secret_key=aws_secret_access_key, + token=aws_session_token, + ) + if region_name is not None: + aio_session.set_config_variable("region", region_name) + return aio_session + + def create_session(self) -> AioSession: + """Create aiobotocore Session from connection and config.""" + if not self._conn: + self.log.info("No connection ID provided. Fallback on boto3 credential strategy") + return get_session() + elif not self.role_arn: + return self._basic_session + return self._get_session_with_assume_role() + + +class AwsBaseAsyncHook(AwsBaseHook): + """ + Interacts with AWS using aiobotocore asynchronously. + + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is None or empty then the default botocore behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default botocore configuration would be used (and must be + maintained on each worker node). + :param verify: Whether to verify SSL certificates. 
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param client_type: boto3.client client_type. Eg 's3', 'emr' etc + :param resource_type: boto3.resource resource_type. Eg 'dynamodb' etc + :param config: Configuration for botocore client. + """ + + def get_async_session(self) -> AioSession: + """Get the underlying aiobotocore.session.AioSession(...).""" + return BaseAsyncSessionFactory( + conn=self.conn_config, region_name=self.region_name, config=self.config + ).create_session() + + async def get_client_async(self): + """Get the underlying aiobotocore client using aiobotocore session""" + return self.get_async_session().create_client( + self.client_type, + region_name=self.region_name, + verify=self.verify, + endpoint_url=self.conn_config.endpoint_url, + config=self.config, + ) diff --git a/airflow/providers/amazon/aws/hooks/redshift_cluster.py b/airflow/providers/amazon/aws/hooks/redshift_cluster.py index 99d59a3c5aaa5..84e97be3806e0 100644 --- a/airflow/providers/amazon/aws/hooks/redshift_cluster.py +++ b/airflow/providers/amazon/aws/hooks/redshift_cluster.py @@ -16,12 +16,14 @@ # under the License. from __future__ import annotations +import asyncio import warnings from typing import Any, Sequence +import botocore.exceptions from botocore.exceptions import ClientError -from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseAsyncHook, AwsBaseHook class RedshiftHook(AwsBaseHook): @@ -200,3 +202,82 @@ def get_cluster_snapshot_status(self, snapshot_identifier: str, cluster_identifi return snapshot_status except self.get_conn().exceptions.ClusterSnapshotNotFoundFault: return None + + +class RedshiftAsyncHook(AwsBaseAsyncHook): + """Interact with AWS Redshift using aiobotocore library""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + kwargs["client_type"] = "redshift" + super().__init__(*args, **kwargs) + + async def cluster_status(self, cluster_identifier: str, delete_operation: bool = False) -> dict[str, Any]: + """ + Connects to the AWS redshift cluster via aiobotocore and get the status + and returns the status of the cluster based on the cluster_identifier passed + + :param cluster_identifier: unique identifier of a cluster + :param delete_operation: whether the method has been called as part of delete cluster operation + """ + async with await self.get_client_async() as client: + try: + response = await client.describe_clusters(ClusterIdentifier=cluster_identifier) + cluster_state = ( + response["Clusters"][0]["ClusterStatus"] if response and response["Clusters"] else None + ) + return {"status": "success", "cluster_state": cluster_state} + except botocore.exceptions.ClientError as error: + if delete_operation and error.response.get("Error", {}).get("Code", "") == "ClusterNotFound": + return {"status": "success", "cluster_state": "cluster_not_found"} + return {"status": "error", "message": str(error)} + + async def pause_cluster(self, cluster_identifier: str, poll_interval: float = 5.0) -> dict[str, Any]: + """ + Connects to the AWS redshift cluster via aiobotocore and + pause the cluster based on the cluster_identifier passed + + :param cluster_identifier: unique identifier of a cluster + :param poll_interval: polling period in seconds to check for the status + """ + try: + async with await self.get_client_async() as client: + response = await client.pause_cluster(ClusterIdentifier=cluster_identifier) + status = 
response["Cluster"]["ClusterStatus"] if response and response["Cluster"] else None + if status == "pausing": + flag = asyncio.Event() + while True: + expected_response = await asyncio.create_task( + self.get_cluster_status(cluster_identifier, "paused", flag) + ) + await asyncio.sleep(poll_interval) + if flag.is_set(): + return expected_response + return {"status": "error", "cluster_state": status} + except botocore.exceptions.ClientError as error: + return {"status": "error", "message": str(error)} + + async def get_cluster_status( + self, + cluster_identifier: str, + expected_state: str, + flag: asyncio.Event, + delete_operation: bool = False, + ) -> dict[str, Any]: + """ + check for expected Redshift cluster state + + :param cluster_identifier: unique identifier of a cluster + :param expected_state: expected_state example("available", "pausing", "paused"") + :param flag: asyncio even flag set true if success and if any error + :param delete_operation: whether the method has been called as part of delete cluster operation + """ + try: + response = await self.cluster_status(cluster_identifier, delete_operation=delete_operation) + if ("cluster_state" in response and response["cluster_state"] == expected_state) or response[ + "status" + ] == "error": + flag.set() + return response + except botocore.exceptions.ClientError as error: + flag.set() + return {"status": "error", "message": str(error)} diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py index 0771d67a370f7..01fdea153adf9 100644 --- a/airflow/providers/amazon/aws/operators/redshift_cluster.py +++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py @@ -22,6 +22,7 @@ from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook +from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftClusterTrigger if TYPE_CHECKING: from airflow.utils.context import Context @@ -447,6 +448,7 @@ class RedshiftPauseClusterOperator(BaseOperator): :param cluster_identifier: id of the AWS Redshift Cluster :param aws_conn_id: aws connection to use + :param deferrable: Run operator in the deferrable mode. This mode requires an additional aiobotocore>= """ template_fields: Sequence[str] = ("cluster_identifier",) @@ -458,11 +460,15 @@ def __init__( *, cluster_identifier: str, aws_conn_id: str = "aws_default", + deferrable: bool = False, + poll_interval: int = 10, **kwargs, ): super().__init__(**kwargs) self.cluster_identifier = cluster_identifier self.aws_conn_id = aws_conn_id + self.deferrable = deferrable + self.poll_interval = poll_interval # These parameters are added to address an issue with the boto3 API where the API # prematurely reports the cluster as available to receive requests. This causes the cluster # to reject initial attempts to pause the cluster despite reporting the correct state. @@ -472,18 +478,48 @@ def __init__( def execute(self, context: Context): redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id) - while self._attempts >= 1: - try: - redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier) - return - except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error: - self._attempts = self._attempts - 1 - - if self._attempts > 0: - self.log.error("Unable to pause cluster. 
%d attempts remaining.", self._attempts) - time.sleep(self._attempt_interval) - else: - raise error + if self.deferrable: + self.defer( + timeout=self.execution_timeout, + trigger=RedshiftClusterTrigger( + task_id=self.task_id, + poll_interval=self.poll_interval, + aws_conn_id=self.aws_conn_id, + cluster_identifier=self.cluster_identifier, + attempts=self._attempts, + operation_type="pause_cluster", + ), + method_name="execute_complete", + ) + else: + while self._attempts >= 1: + try: + redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier) + return + except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error: + self._attempts = self._attempts - 1 + + if self._attempts > 0: + self.log.error("Unable to pause cluster. %d attempts remaining.", self._attempts) + time.sleep(self._attempt_interval) + else: + raise error + + def execute_complete(self, context: Context, event: Any = None) -> None: + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + if event: + if "status" in event and event["status"] == "error": + msg = f"{event['status']}: {event['message']}" + raise AirflowException(msg) + elif "status" in event and event["status"] == "success": + self.log.info("%s completed successfully.", self.task_id) + self.log.info("Paused cluster successfully") + else: + raise AirflowException("No event received from trigger") class RedshiftDeleteClusterOperator(BaseOperator): diff --git a/airflow/providers/amazon/aws/triggers/__init__.py b/airflow/providers/amazon/aws/triggers/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/airflow/providers/amazon/aws/triggers/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/amazon/aws/triggers/redshift_cluster.py b/airflow/providers/amazon/aws/triggers/redshift_cluster.py new file mode 100644 index 0000000000000..fb5c36d375e4a --- /dev/null +++ b/airflow/providers/amazon/aws/triggers/redshift_cluster.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Any, AsyncIterator + +from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftAsyncHook +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class RedshiftClusterTrigger(BaseTrigger): + """AWS Redshift trigger""" + + def __init__( + self, + task_id: str, + aws_conn_id: str, + cluster_identifier: str, + operation_type: str, + attempts: int, + poll_interval: float = 5.0, + ): + super().__init__() + self.task_id = task_id + self.poll_interval = poll_interval + self.aws_conn_id = aws_conn_id + self.cluster_identifier = cluster_identifier + self.operation_type = operation_type + self.attempts = attempts + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftClusterTrigger", + { + "task_id": self.task_id, + "poll_interval": self.poll_interval, + "aws_conn_id": self.aws_conn_id, + "cluster_identifier": self.cluster_identifier, + "attempts": self.attempts, + "operation_type": self.operation_type, + }, + ) + + async def run(self) -> AsyncIterator["TriggerEvent"]: + hook = RedshiftAsyncHook(aws_conn_id=self.aws_conn_id) + while self.attempts >= 1: + self.attempts = self.attempts - 1 + try: + if self.operation_type == "pause_cluster": + response = await hook.pause_cluster( + cluster_identifier=self.cluster_identifier, + poll_interval=self.poll_interval, + ) + if response.get("status") == "success": + yield TriggerEvent(response) + else: + if self.attempts < 1: + yield TriggerEvent({"status": "error", "message": f"{self.task_id} failed"}) + else: + yield TriggerEvent(f"{self.operation_type} is not supported") + except Exception as e: + if self.attempts < 1: + yield TriggerEvent({"status": "error", "message": str(e)}) diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index b73c293cf1183..431efca824e12 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -56,6 +56,7 @@ dependencies: - apache-airflow>=2.3.0 - apache-airflow-providers-common-sql>=1.3.1 - boto3>=1.24.0 + - asgiref # watchtower 3 has been released end Jan and introduced breaking change across the board that might # change logging behaviour: # https://github.com/kislyuk/watchtower/blob/develop/Changes.rst#changes-for-v300-2022-01-26 @@ -597,3 +598,6 @@ additional-extras: - name: pandas dependencies: - pandas>=0.17.1 + - name: aiobotocore + dependencies: + - aiobotocore>=2.1.1 diff --git a/docs/apache-airflow-providers-amazon/deferrable.rst b/docs/apache-airflow-providers-amazon/deferrable.rst new file mode 100644 index 0000000000000..9f5f94c5d73d2 --- /dev/null +++ b/docs/apache-airflow-providers-amazon/deferrable.rst @@ -0,0 +1,33 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. 
The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +======================== +AWS Deferrable Operators +======================== + +The AWS deferrable operators depend on the ``aiobotocore>=2.1.1`` library. Unfortunately, we currently can't add it to the +core AWS provider dependencies because of a conflicting ``botocore`` version requirement between ``aiobotocore`` and ``boto3``. +We have added ``aiobotocore`` as an additional dependency, so if you want to use the AWS deferrable operators you will have to +manage this dependency yourself. + +We have introduced an async hook to manage authentication to the AWS services asynchronously. The +AWS async hook currently supports the default botocore authentication mechanism, i.e. if no Airflow connection is +provided then the provider will try to find the credentials in environment variables. If an Airflow connection is +provided then basic auth with secret-key/access-key-id/profile/token and arn-method should work. + +We have exposed a ``deferrable`` param in those operators which support deferrable execution. +By default ``deferrable`` is set to ``False``; set it to ``True`` to run the operator in asynchronous mode. diff --git a/docs/apache-airflow-providers-amazon/index.rst b/docs/apache-airflow-providers-amazon/index.rst index 8722392b14a71..a796325870301 100644 --- a/docs/apache-airflow-providers-amazon/index.rst +++ b/docs/apache-airflow-providers-amazon/index.rst @@ -27,6 +27,7 @@ Content Connection types Operators + Deferrable Operators Secrets backends Logging for Tasks diff --git a/docs/apache-airflow-providers-amazon/operators/redshift_cluster.rst b/docs/apache-airflow-providers-amazon/operators/redshift_cluster.rst index 9902cdb4c9a12..b85eddd2c3e29 100644 --- a/docs/apache-airflow-providers-amazon/operators/redshift_cluster.rst +++ b/docs/apache-airflow-providers-amazon/operators/redshift_cluster.rst @@ -65,8 +65,9 @@ To resume a 'paused' Amazon Redshift cluster you can use Pause an Amazon Redshift cluster ================================ -To pause an 'available' Amazon Redshift cluster you can use -:class:`RedshiftPauseClusterOperator ` +To pause an ``available`` Amazon Redshift cluster you can use +:class:`RedshiftPauseClusterOperator `. +You can also run this operator in deferrable mode by setting the ``deferrable`` param to ``True``. ..
exampleinclude:: /../../tests/system/providers/amazon/aws/example_redshift.py :language: python diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 533255a084f79..a980d751396cf 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -16,6 +16,8 @@ adls afterall AgentKey aio +aiobotocore +AioSession Airbnb airbnb Airbyte diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index bf99bbad4a230..716156b8ad5af 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -19,6 +19,7 @@ "deps": [ "apache-airflow-providers-common-sql>=1.3.1", "apache-airflow>=2.3.0", + "asgiref", "boto3>=1.24.0", "jsonpath_ng>=1.5.3", "mypy-boto3-appflow>=1.24.0", diff --git a/tests/providers/amazon/aws/deferrable/__init__.py b/tests/providers/amazon/aws/deferrable/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/providers/amazon/aws/deferrable/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/amazon/aws/deferrable/hooks/__init__.py b/tests/providers/amazon/aws/deferrable/hooks/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/providers/amazon/aws/deferrable/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/amazon/aws/deferrable/hooks/test_base_aws.py b/tests/providers/amazon/aws/deferrable/hooks/test_base_aws.py new file mode 100644 index 0000000000000..d50f6dcd16e02 --- /dev/null +++ b/tests/providers/amazon/aws/deferrable/hooks/test_base_aws.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import ANY + +import pytest + +from airflow.models.connection import Connection +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseAsyncHook +from tests.providers.amazon.aws.utils.compat import async_mock + +try: + from aiobotocore.credentials import AioCredentials +except ImportError: + pass + +pytest.importorskip("aiobotocore") + + +class TestAwsBaseAsyncHook: + @staticmethod + def compare_aio_cred(first, second): + if not type(first) == type(second): + return False + if first.access_key != second.access_key: + return False + if first.secret_key != second.secret_key: + return False + if first.method != second.method: + return False + if first.token != second.token: + return False + return True + + class Matcher: + def __init__(self, compare, obj): + self.compare = compare + self.obj = obj + + def __eq__(self, other): + return self.compare(self.obj, other) + + @pytest.mark.asyncio + @async_mock.patch("airflow.hooks.base.BaseHook.get_connection") + @async_mock.patch("aiobotocore.session.AioClientCreator.create_client") + async def test_get_client_async(self, mock_client, mock_get_connection): + """Check the connection credential passed while creating client""" + mock_get_connection.return_value = Connection( + conn_id="aws_default1", + extra="""{ + "aws_secret_access_key": "mock_aws_access_key", + "aws_access_key_id": "mock_aws_access_key_id", + "region_name": "us-east-2" + }""", + ) + + hook = AwsBaseAsyncHook(client_type="S3", aws_conn_id="aws_default1") + cred = await hook.get_async_session().get_credentials() + + # check credential have same values as intended + assert cred.__dict__ == { + "access_key": "mock_aws_access_key_id", + "method": "explicit", + "secret_key": "mock_aws_access_key", + "token": None, + } + + aio_cred = AioCredentials( + access_key="mock_aws_access_key_id", + method="explicit", + secret_key="mock_aws_access_key", + ) + + await (await hook.get_client_async())._coro + # Test the aiobotocore client created with right param + mock_client.assert_called_once_with( + service_name="S3", + region_name="us-east-2", + is_secure=True, + endpoint_url=None, + verify=None, + credentials=self.Matcher(self.compare_aio_cred, aio_cred), + scoped_config={}, + client_config=ANY, + api_version=None, + auth_token=None, + ) diff --git a/tests/providers/amazon/aws/deferrable/hooks/test_redshift_cluster.py b/tests/providers/amazon/aws/deferrable/hooks/test_redshift_cluster.py new file mode 100644 index 0000000000000..3cbc39cedabc9 --- /dev/null +++ b/tests/providers/amazon/aws/deferrable/hooks/test_redshift_cluster.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio + +import pytest +from botocore.exceptions import ClientError + +from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftAsyncHook +from tests.providers.amazon.aws.utils.compat import async_mock + +pytest.importorskip("aiobotocore") + + +class TestRedshiftAsyncHook: + @pytest.mark.asyncio + @async_mock.patch("aiobotocore.client.AioBaseClient._make_api_call") + async def test_cluster_status(self, mock_make_api_call): + """Test that describe_clusters get called with correct param""" + hook = RedshiftAsyncHook(aws_conn_id="aws_default", client_type="redshift", resource_type="redshift") + await hook.cluster_status(cluster_identifier="redshift_cluster_1") + mock_make_api_call.assert_called_once_with( + "DescribeClusters", {"ClusterIdentifier": "redshift_cluster_1"} + ) + + @pytest.mark.asyncio + @async_mock.patch("aiobotocore.client.AioBaseClient._make_api_call") + async def test_pause_cluster(self, mock_make_api_call): + """Test that pause_cluster get called with correct param""" + hook = RedshiftAsyncHook(aws_conn_id="aws_default", client_type="redshift", resource_type="redshift") + await hook.pause_cluster(cluster_identifier="redshift_cluster_1") + mock_make_api_call.assert_called_once_with( + "PauseCluster", {"ClusterIdentifier": "redshift_cluster_1"} + ) + + @pytest.mark.asyncio + @async_mock.patch( + "airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.get_client_async" + ) + @async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.cluster_status") + async def test_get_cluster_status(self, cluster_status, mock_client): + """Test get_cluster_status async function with success response""" + flag = asyncio.Event() + cluster_status.return_value = {"status": "success", "cluster_state": "available"} + hook = RedshiftAsyncHook(aws_conn_id="aws_default") + result = await hook.get_cluster_status("redshift_cluster_1", "available", flag) + assert result == {"status": "success", "cluster_state": "available"} + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.cluster_status") + async def test_get_cluster_status_exception(self, cluster_status): + """Test get_cluster_status async function with exception response""" + flag = asyncio.Event() + cluster_status.side_effect = ClientError( + { + "Error": { + "Code": "SomeServiceException", + "Message": "Details/context around the exception or error", + }, + }, + operation_name="redshift", + ) + hook = RedshiftAsyncHook(aws_conn_id="aws_default") + result = await hook.get_cluster_status("test-identifier", "available", flag) + assert result == { + "status": "error", + "message": "An error occurred (SomeServiceException) when calling the " + "redshift operation: Details/context around the exception or error", + } diff --git a/tests/providers/amazon/aws/deferrable/triggers/__init__.py 
b/tests/providers/amazon/aws/deferrable/triggers/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/tests/providers/amazon/aws/deferrable/triggers/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/amazon/aws/deferrable/triggers/test_redshift_cluster.py b/tests/providers/amazon/aws/deferrable/triggers/test_redshift_cluster.py new file mode 100644 index 0000000000000..fe34233895323 --- /dev/null +++ b/tests/providers/amazon/aws/deferrable/triggers/test_redshift_cluster.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import importlib.util + +import pytest + +from airflow.providers.amazon.aws.triggers.redshift_cluster import ( + RedshiftClusterTrigger, +) +from airflow.triggers.base import TriggerEvent +from tests.providers.amazon.aws.utils.compat import async_mock + +TASK_ID = "redshift_trigger_check" +POLLING_PERIOD_SECONDS = 1.0 + + +@pytest.mark.skipif(not bool(importlib.util.find_spec("aiobotocore")), reason="aiobotocore require") +class TestRedshiftClusterTrigger: + def test_pause_serialization(self): + """ + Asserts that the RedshiftClusterTrigger correctly serializes its arguments + and classpath. 
+ """ + trigger = RedshiftClusterTrigger( + task_id=TASK_ID, + poll_interval=POLLING_PERIOD_SECONDS, + aws_conn_id="test_redshift_conn_id", + cluster_identifier="mock_cluster_identifier", + attempts=10, + operation_type="pause_cluster", + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftClusterTrigger" + assert kwargs == { + "task_id": TASK_ID, + "poll_interval": POLLING_PERIOD_SECONDS, + "aws_conn_id": "test_redshift_conn_id", + "cluster_identifier": "mock_cluster_identifier", + "attempts": 10, + "operation_type": "pause_cluster", + } + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.pause_cluster") + async def test_pause_trigger_run(self, mock_pause_cluster): + """ + Test trigger event for the pause_cluster response + """ + trigger = RedshiftClusterTrigger( + task_id=TASK_ID, + poll_interval=POLLING_PERIOD_SECONDS, + aws_conn_id="test_redshift_conn_id", + cluster_identifier="mock_cluster_identifier", + attempts=1, + operation_type="pause_cluster", + ) + generator = trigger.run() + await generator.asend(None) + mock_pause_cluster.assert_called_once_with( + cluster_identifier="mock_cluster_identifier", poll_interval=1.0 + ) + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.pause_cluster") + async def test_pause_trigger_failure(self, mock_pause_cluster): + """Test trigger event when pause cluster raise exception""" + mock_pause_cluster.side_effect = Exception("Test exception") + trigger = RedshiftClusterTrigger( + task_id=TASK_ID, + poll_interval=POLLING_PERIOD_SECONDS, + aws_conn_id="test_redshift_conn_id", + cluster_identifier="mock_cluster_identifier", + attempts=1, + operation_type="pause_cluster", + ) + task = [i async for i in trigger.run()] + assert TriggerEvent({"status": "error", "message": "Test exception"}) in task diff --git a/tests/providers/amazon/aws/operators/test_redshift_cluster.py b/tests/providers/amazon/aws/operators/test_redshift_cluster.py index 71e73ea8023e9..e2d6a16888061 100644 --- a/tests/providers/amazon/aws/operators/test_redshift_cluster.py +++ b/tests/providers/amazon/aws/operators/test_redshift_cluster.py @@ -21,7 +21,7 @@ import boto3 import pytest -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook from airflow.providers.amazon.aws.operators.redshift_cluster import ( RedshiftCreateClusterOperator, @@ -31,6 +31,7 @@ RedshiftPauseClusterOperator, RedshiftResumeClusterOperator, ) +from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftClusterTrigger class TestRedshiftCreateClusterOperator: @@ -311,6 +312,39 @@ def test_pause_cluster_multiple_attempts_fail(self, mock_sleep, mock_conn): redshift_operator.execute(None) assert mock_conn.pause_cluster.call_count == 10 + def test_pause_cluster_deferrable_mode(self): + """Test Pause cluster operator with defer when deferrable param is true""" + + redshift_operator = RedshiftPauseClusterOperator( + task_id="task_test", cluster_identifier="test_cluster", deferrable=True + ) + + with pytest.raises(TaskDeferred) as exc: + redshift_operator.execute(context=None) + + assert isinstance( + exc.value.trigger, RedshiftClusterTrigger + ), "Trigger is not a RedshiftClusterTrigger" + + def test_pause_cluster_execute_complete_success(self): + """Asserts that logging 
occurs as expected""" + task = RedshiftPauseClusterOperator( + task_id="task_test", cluster_identifier="test_cluster", deferrable=True + ) + with mock.patch.object(task.log, "info") as mock_log_info: + task.execute_complete(context=None, event={"status": "success"}) + mock_log_info.assert_called_with("Paused cluster successfully") + + def test_pause_cluster_execute_complete_fail(self): + redshift_operator = RedshiftPauseClusterOperator( + task_id="task_test", cluster_identifier="test_cluster", deferrable=True + ) + + with pytest.raises(AirflowException): + redshift_operator.execute_complete( + context=None, event={"status": "error", "message": "test failure message"} + ) + class TestDeleteClusterOperator: @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status") diff --git a/tests/providers/amazon/aws/utils/compat.py b/tests/providers/amazon/aws/utils/compat.py new file mode 100644 index 0000000000000..9c03414adcac4 --- /dev/null +++ b/tests/providers/amazon/aws/utils/compat.py @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +__all__ = ["async_mock", "AsyncMock"] + +import sys + +if sys.version_info < (3, 8): + # For compatibility with Python 3.7 + from asynctest import mock as async_mock + + # ``asynctest.mock.CoroutineMock`` which provide compatibility not working well with autospec=True + # as result "TypeError: object MagicMock can't be used in 'await' expression" could be raised. + # Best solution in this case provide as spec actual awaitable object + # >>> from tests.providers.amazon.cloud.utils.compat import AsyncMock + # >>> from foo.bar import SpamEgg + # >>> mock_something = AsyncMock(SpamEgg) + from asynctest.mock import CoroutineMock as AsyncMock +else: + from unittest import mock as async_mock + from unittest.mock import AsyncMock
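
Example usage (illustrative sketch, not part of the patch): the snippet below shows how the new `deferrable` and `poll_interval` params added by this PR might be used in a DAG. The DAG id, cluster identifier and connection id are placeholders, and it assumes `aiobotocore>=2.1.1` is installed and a triggerer process is running.

from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.redshift_cluster import RedshiftPauseClusterOperator

with DAG(
    dag_id="example_redshift_pause_deferrable",  # hypothetical DAG id
    start_date=datetime(2023, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    # With deferrable=True the operator defers to RedshiftClusterTrigger, which issues the
    # pause request via aiobotocore and polls every poll_interval seconds until the cluster
    # reaches the "paused" state, freeing the worker slot in the meantime.
    pause_cluster = RedshiftPauseClusterOperator(
        task_id="pause_redshift_cluster",
        cluster_identifier="my-redshift-cluster",  # placeholder cluster identifier
        aws_conn_id="aws_default",
        deferrable=True,
        poll_interval=30,
    )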