
Add deferrable mode in RedshiftPauseClusterOperator #28850

Merged (8 commits, Mar 13, 2023)
Changes from 4 commits
24 changes: 24 additions & 0 deletions .github/workflows/ci.yml
@@ -801,6 +801,30 @@ jobs:
run: breeze ci fix-ownership
if: always()

tests-aws-async-provider:
timeout-minutes: 50
name: "Pytest for AWS Async Provider"
runs-on: "${{needs.build-info.outputs.runs-on}}"
needs: [build-info, wait-for-ci-images]
if: needs.build-info.outputs.run-tests == 'true'
Comment on lines +804 to +809

Contributor: I guess it will run even if there are no changes in the AWS provider.

Member Author: Yes, what should we do here?

Contributor: I think we should run this only on Amazon provider changes. @potiuk might have a suggestion, as the original author of the selective checks for providers.

steps:
- name: Cleanup repo
shell: bash
run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*"
- name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )"
uses: actions/checkout@v3
with:
persist-credentials: false
- name: "Prepare breeze & CI image"
uses: ./.github/actions/prepare_breeze_and_image
- name: "Run AWS Async Test"
run: "breeze shell \
'pip install aiobotocore>=2.1.1 && pytest /opt/airflow/tests/providers/amazon/aws/deferrable'"
Comment on lines +820 to +822

Contributor: None of the other test settings apply to this run, including warning collection. I do not know whether that is required right now or not.

Member Author: Right, but I'm also not sure whether that's required right now.

Contributor: I guess this will be resolved by the time the PR is ready to merge?

- name: "Post Tests"
uses: ./.github/actions/post_tests
- name: "Fix ownership"
run: breeze ci fix-ownership
if: always()

tests-helm:
timeout-minutes: 80
131 changes: 129 additions & 2 deletions airflow/providers/amazon/aws/hooks/base_aws.py
@@ -51,7 +51,10 @@

from airflow.compat.functools import cached_property
from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.exceptions import (
AirflowException,
AirflowNotFoundException,
)
from airflow.hooks.base import BaseHook
from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
from airflow.providers.amazon.aws.waiters.base_waiter import BaseBotoWaiter
@@ -62,7 +65,6 @@

BaseAwsConnection = TypeVar("BaseAwsConnection", bound=Union[boto3.client, boto3.resource])


if TYPE_CHECKING:
from airflow.models.connection import Connection # Avoid circular imports.

@@ -877,3 +879,128 @@ def _parse_s3_config(config_file_name: str, config_format: str | None = "boto",
config_format=config_format,
profile=profile,
)


try:
import aiobotocore.credentials
from aiobotocore.session import AioSession, get_session
except ImportError:
pass


class BaseAsyncSessionFactory(BaseSessionFactory):
"""
Base AWS Session Factory class to handle aiobotocore session creation.
It currently handles ENV, AWS secret key, and the STS client method ``assume_role``
provided in the Airflow connection.
"""

async def get_role_credentials(self) -> dict:
"""Get the role_arn, method credentials from connection details and get the role credentials detail"""
async with self._basic_session.create_client("sts", region_name=self.region_name) as client:
response = await client.assume_role(
RoleArn=self.role_arn,
RoleSessionName=self._strip_invalid_session_name_characters(f"Airflow_{self.conn.conn_id}"),
**self.conn.assume_role_kwargs,
)
return response["Credentials"]

async def _get_refresh_credentials(self) -> dict[str, Any]:
self.log.debug("Refreshing credentials")
assume_role_method = self.conn.assume_role_method
if assume_role_method != "assume_role":
raise NotImplementedError(f"assume_role_method={assume_role_method} not expected")

credentials = await self.get_role_credentials()

expiry_time = credentials["Expiration"].isoformat()
self.log.debug("New credentials expiry_time: %s", expiry_time)
credentials = {
"access_key": credentials.get("AccessKeyId"),
"secret_key": credentials.get("SecretAccessKey"),
"token": credentials.get("SessionToken"),
"expiry_time": expiry_time,
}
return credentials

def _get_session_with_assume_role(self) -> AioSession:

assume_role_method = self.conn.assume_role_method
if assume_role_method != "assume_role":
raise NotImplementedError(f"assume_role_method={assume_role_method} not expected")

credentials = aiobotocore.credentials.AioRefreshableCredentials.create_from_metadata(
metadata=self._get_refresh_credentials(),
refresh_using=self._get_refresh_credentials,
method="sts-assume-role",
)

session = aiobotocore.session.get_session()
session._credentials = credentials
return session

@cached_property
def _basic_session(self) -> AioSession:
"""Cached property with basic aiobotocore.session.AioSession."""
session_kwargs = self.conn.session_kwargs
aws_access_key_id = session_kwargs.get("aws_access_key_id")
aws_secret_access_key = session_kwargs.get("aws_secret_access_key")
aws_session_token = session_kwargs.get("aws_session_token")
region_name = session_kwargs.get("region_name")
profile_name = session_kwargs.get("profile_name")

aio_session = get_session()
if profile_name is not None:
aio_session.set_config_variable("profile", profile_name)
if aws_access_key_id or aws_secret_access_key or aws_session_token:
aio_session.set_credentials(
access_key=aws_access_key_id,
secret_key=aws_secret_access_key,
token=aws_session_token,
)
if region_name is not None:
aio_session.set_config_variable("region", region_name)
return aio_session

def create_session(self) -> AioSession:
"""Create aiobotocore Session from connection and config."""
if not self._conn:
self.log.info("No connection ID provided. Fallback on boto3 credential strategy")
return get_session()
elif not self.role_arn:
return self._basic_session
return self._get_session_with_assume_role()


class AwsBaseAsyncHook(AwsBaseHook):
"""
Interacts with AWS using aiobotocore asynchronously.
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default botocore behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default botocore configuration would be used (and must be
maintained on each worker node).
:param verify: Whether to verify SSL certificates.
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param client_type: boto3.client client_type, e.g. 's3', 'emr', etc.
:param resource_type: boto3.resource resource_type, e.g. 'dynamodb', etc.
:param config: Configuration for botocore client.
"""

def get_async_session(self) -> AioSession:
"""Get the underlying aiobotocore.session.AioSession(...)."""
return BaseAsyncSessionFactory(
conn=self.conn_config, region_name=self.region_name, config=self.config
).create_session()

async def get_client_async(self):
"""Get the underlying aiobotocore client using aiobotocore session"""
return self.get_async_session().create_client(
self.client_type,
region_name=self.region_name,
verify=self.verify,
endpoint_url=self.conn_config.endpoint_url,
config=self.config,
)
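For a sense of how this hook is meant to be consumed, here is a minimal sketch; the connection id, client_type, and the standalone asyncio.run driver are illustrative assumptions, not code from this PR:

```python
from __future__ import annotations

import asyncio

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseAsyncHook


async def list_redshift_clusters() -> list[str]:
    # aws_conn_id and client_type are placeholder choices for this sketch.
    hook = AwsBaseAsyncHook(aws_conn_id="aws_default", client_type="redshift")
    # get_client_async() returns an aiobotocore client context; enter it with `async with`.
    async with await hook.get_client_async() as client:
        response = await client.describe_clusters()
        return [c["ClusterIdentifier"] for c in response.get("Clusters", [])]


if __name__ == "__main__":
    print(asyncio.run(list_redshift_clusters()))
```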
83 changes: 82 additions & 1 deletion airflow/providers/amazon/aws/hooks/redshift_cluster.py
@@ -16,12 +16,14 @@
# under the License.
from __future__ import annotations

import asyncio
import warnings
from typing import Any, Sequence

import botocore.exceptions
from botocore.exceptions import ClientError

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseAsyncHook, AwsBaseHook


class RedshiftHook(AwsBaseHook):
@@ -200,3 +202,82 @@ def get_cluster_snapshot_status(self, snapshot_identifier: str, cluster_identifi
return snapshot_status
except self.get_conn().exceptions.ClusterSnapshotNotFoundFault:
return None


class RedshiftAsyncHook(AwsBaseAsyncHook):
"""Interact with AWS Redshift using aiobotocore library"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["client_type"] = "redshift"
super().__init__(*args, **kwargs)

async def cluster_status(self, cluster_identifier: str, delete_operation: bool = False) -> dict[str, Any]:
"""
Connect to the AWS Redshift cluster via aiobotocore and return the status
of the cluster for the given cluster_identifier.
:param cluster_identifier: unique identifier of a cluster
:param delete_operation: whether the method has been called as part of delete cluster operation
"""
async with await self.get_client_async() as client:
try:
response = await client.describe_clusters(ClusterIdentifier=cluster_identifier)
cluster_state = (
response["Clusters"][0]["ClusterStatus"] if response and response["Clusters"] else None
)
return {"status": "success", "cluster_state": cluster_state}
except botocore.exceptions.ClientError as error:
if delete_operation and error.response.get("Error", {}).get("Code", "") == "ClusterNotFound":
return {"status": "success", "cluster_state": "cluster_not_found"}
return {"status": "error", "message": str(error)}

async def pause_cluster(self, cluster_identifier: str, poll_interval: float = 5.0) -> dict[str, Any]:
"""
Connect to the AWS Redshift cluster via aiobotocore and
pause the cluster with the given cluster_identifier.
:param cluster_identifier: unique identifier of a cluster
:param poll_interval: polling period in seconds to check for the status
"""
try:
async with await self.get_client_async() as client:
response = await client.pause_cluster(ClusterIdentifier=cluster_identifier)
status = response["Cluster"]["ClusterStatus"] if response and response["Cluster"] else None
if status == "pausing":
flag = asyncio.Event()
while True:
expected_response = await asyncio.create_task(
self.get_cluster_status(cluster_identifier, "paused", flag)
)
await asyncio.sleep(poll_interval)
if flag.is_set():
return expected_response
return {"status": "error", "cluster_state": status}
except botocore.exceptions.ClientError as error:
return {"status": "error", "message": str(error)}

async def get_cluster_status(
self,
cluster_identifier: str,
expected_state: str,
flag: asyncio.Event,
delete_operation: bool = False,
) -> dict[str, Any]:
"""
Check for the expected Redshift cluster state.
:param cluster_identifier: unique identifier of a cluster
:param expected_state: the expected state, e.g. "available", "pausing", "paused"
:param flag: asyncio Event that is set once the expected state is reached or an error occurs
:param delete_operation: whether the method has been called as part of delete cluster operation
"""
try:
response = await self.cluster_status(cluster_identifier, delete_operation=delete_operation)
if ("cluster_state" in response and response["cluster_state"] == expected_state) or response[
"status"
] == "error":
flag.set()
return response
except botocore.exceptions.ClientError as error:
flag.set()
return {"status": "error", "message": str(error)}
60 changes: 48 additions & 12 deletions airflow/providers/amazon/aws/operators/redshift_cluster.py
@@ -22,6 +22,7 @@
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftClusterTrigger

if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -447,6 +448,7 @@ class RedshiftPauseClusterOperator(BaseOperator):
:param cluster_identifier: id of the AWS Redshift Cluster
:param aws_conn_id: aws connection to use
:param deferrable: Run operator in the deferrable mode. This mode requires the additional dependency aiobotocore>=2.1.1.
"""

template_fields: Sequence[str] = ("cluster_identifier",)
@@ -458,11 +460,15 @@ def __init__(
*,
cluster_identifier: str,
aws_conn_id: str = "aws_default",
deferrable: bool = False,
poll_interval: int = 10,
Comment on lines +463 to +464

Contributor: Just wondering how we would resolve this if we want to use the more generic parameters wait_for_completion as well as waiter_delay and waiter_max_attempts.

@ferruzzi @vincbeck Are there any plans to migrate other AWS operators to use waiters?

Contributor:

> Are there any plans to migrate other AWS operators to use waiters?

Personally, I'd like to see all operators that implement a wait_for_completion block use a waiter (either an official one or a custom one), but it's a matter of time. I am working on moving the EMR operators over as I get time, but it's not a huge priority (yet?), and maybe it's better suited to a community project with a discussion to track the progress, like we've done in the past?
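For illustration only, the waiter-based pattern being discussed would map the generic waiter_delay / waiter_max_attempts parameters onto an official boto3 waiter roughly like this (a sketch under assumed parameter names, not code from this PR):

```python
import boto3


def wait_for_cluster_available(cluster_identifier: str, waiter_delay: int = 30, waiter_max_attempts: int = 20) -> None:
    # "cluster_available" is one of the official boto3 Redshift waiters; the generic
    # operator parameters would simply feed into its WaiterConfig.
    client = boto3.client("redshift")
    waiter = client.get_waiter("cluster_available")
    waiter.wait(
        ClusterIdentifier=cluster_identifier,
        WaiterConfig={"Delay": waiter_delay, "MaxAttempts": waiter_max_attempts},
    )
```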

Contributor: For the first part, I don't know offhand, but @syedahsn is working on deferrable stuff on our end and might have a better answer there.

Contributor: Any chance that botocore will have native asyncio support 🤔 Or is it just some kind of surprise 🤣

Contributor: I can't speak for the botocore team on that; I don't have any contact with them, so I wouldn't know any better than you on that one. I'm not aware of anything in that regard, but I wouldn't really know. Syed is researching deferrable operators to start making the full suite of AWS operators deferrable, so out of our team I'd say he's the closest to an expert on aiobotocore and related topics.

Contributor: It's been requested for a looong time: boto/botocore#458
We reached out to them and they said botocore will eventually support async by default, but it isn't coming any time soon.

Contributor: Yeah, I think we won't have native support soon. Just hoping that points to something from the inside 🤣

Contributor: Our team is currently working on implementing deferrable operators as well, and we're taking a slightly different approach. Rather than create async counterparts for all AWS hooks, we are planning to create an async_conn property similar to the existing conn property. The async_conn property will return a ClientCreatorContext object that can be used to create an aiobotocore client that can make async calls to the boto3 API. With this approach, we reduce a lot of complexity and code duplication by not having separate async hooks for every service. Additionally, we are focusing on using waiters to perform the polling operations so that most of the triggers will be in a standard format. The aiobotocore library has an AIOWaiter class that makes asynchronous calls to the boto3 API, which makes it very easy to implement most waiters. This also saves us from having to write custom describe calls for every service. We are currently reviewing the design, but will have a PR for the community shortly.
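A rough sketch of the async_conn plus AIOWaiter idea described above; the class, names, and the waiter chosen here are assumptions for illustration (and credential handling from the Airflow connection is omitted), not code from this PR or the proposal:

```python
from __future__ import annotations

import asyncio

from aiobotocore.session import get_session


class AsyncConnExample:
    """Minimal stand-in for a hook exposing an ``async_conn`` property."""

    def __init__(self, client_type: str, region_name: str | None = None):
        self.client_type = client_type
        self.region_name = region_name

    @property
    def async_conn(self):
        # AioSession.create_client() returns a ClientCreatorContext; entering it with
        # ``async with`` yields the aiobotocore client.
        return get_session().create_client(self.client_type, region_name=self.region_name)


async def wait_until_available(cluster_identifier: str) -> None:
    hook = AsyncConnExample(client_type="redshift", region_name="us-east-1")
    async with hook.async_conn as client:
        # aiobotocore wraps botocore waiters in AIOWaiter, so wait() is awaitable;
        # "cluster_available" is one of the official Redshift waiters.
        waiter = client.get_waiter("cluster_available")
        await waiter.wait(ClusterIdentifier=cluster_identifier)


if __name__ == "__main__":
    asyncio.run(wait_until_available("my-redshift-cluster"))
```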

**kwargs,
):
super().__init__(**kwargs)
self.cluster_identifier = cluster_identifier
self.aws_conn_id = aws_conn_id
self.deferrable = deferrable
self.poll_interval = poll_interval
# These parameters are added to address an issue with the boto3 API where the API
# prematurely reports the cluster as available to receive requests. This causes the cluster
# to reject initial attempts to pause the cluster despite reporting the correct state.
@@ -472,18 +478,48 @@ def execute(self, context: Context):
def execute(self, context: Context):
redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)

while self._attempts >= 1:
try:
redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
return
except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
self._attempts = self._attempts - 1

if self._attempts > 0:
self.log.error("Unable to pause cluster. %d attempts remaining.", self._attempts)
time.sleep(self._attempt_interval)
else:
raise error
if self.deferrable:
self.defer(
timeout=self.execution_timeout,
trigger=RedshiftClusterTrigger(
task_id=self.task_id,
poll_interval=self.poll_interval,
aws_conn_id=self.aws_conn_id,
cluster_identifier=self.cluster_identifier,
attempts=self._attempts,
operation_type="pause_cluster",
),
method_name="execute_complete",
)
else:
while self._attempts >= 1:
try:
redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
return
except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
self._attempts = self._attempts - 1

if self._attempts > 0:
self.log.error("Unable to pause cluster. %d attempts remaining.", self._attempts)
time.sleep(self._attempt_interval)
else:
raise error

def execute_complete(self, context: Context, event: Any = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event:
if "status" in event and event["status"] == "error":
msg = f"{event['status']}: {event['message']}"
raise AirflowException(msg)
elif "status" in event and event["status"] == "success":
self.log.info("%s completed successfully.", self.task_id)
self.log.info("Paused cluster successfully")
else:
raise AirflowException("No event received from trigger")
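The RedshiftClusterTrigger used above is not shown in this excerpt; as a heavily simplified sketch (class name, constructor arguments, and module path are assumptions), a trigger that drives the async hook and feeds execute_complete() could look like this:

```python
from __future__ import annotations

from typing import Any, AsyncIterator

from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftAsyncHook
from airflow.triggers.base import BaseTrigger, TriggerEvent


class SketchRedshiftPauseTrigger(BaseTrigger):
    """Illustrative trigger that defers pausing a Redshift cluster to the triggerer."""

    def __init__(self, task_id: str, aws_conn_id: str, cluster_identifier: str, poll_interval: float):
        super().__init__()
        self.task_id = task_id
        self.aws_conn_id = aws_conn_id
        self.cluster_identifier = cluster_identifier
        self.poll_interval = poll_interval

    def serialize(self) -> tuple[str, dict[str, Any]]:
        # Triggers must be serializable so the triggerer process can re-create them.
        return (
            "example_triggers.SketchRedshiftPauseTrigger",
            {
                "task_id": self.task_id,
                "aws_conn_id": self.aws_conn_id,
                "cluster_identifier": self.cluster_identifier,
                "poll_interval": self.poll_interval,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        hook = RedshiftAsyncHook(aws_conn_id=self.aws_conn_id)
        # pause_cluster() already polls until the cluster is "paused" or errors out,
        # so the trigger just forwards its result dict, which execute_complete() inspects.
        result = await hook.pause_cluster(self.cluster_identifier, self.poll_interval)
        yield TriggerEvent(result)
```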


class RedshiftDeleteClusterOperator(BaseOperator):
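A minimal DAG snippet showing how the new deferrable flag on RedshiftPauseClusterOperator would be used; the dag_id, schedule, and cluster name are placeholders:

```python
from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.redshift_cluster import RedshiftPauseClusterOperator

with DAG(
    dag_id="example_redshift_pause_deferrable",
    start_date=datetime(2023, 1, 1),
    schedule=None,
    catchup=False,
) as dag:
    # With deferrable=True the task submits the pause request, then defers to
    # RedshiftClusterTrigger and frees the worker slot while the cluster pauses.
    pause_cluster = RedshiftPauseClusterOperator(
        task_id="pause_redshift_cluster",
        cluster_identifier="my-redshift-cluster",
        aws_conn_id="aws_default",
        deferrable=True,
        poll_interval=10,
    )
```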
17 changes: 17 additions & 0 deletions airflow/providers/amazon/aws/triggers/__init__.py
@@ -0,0 +1,17 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.