Change Deferrable implementation for RedshiftPauseClusterOperator to follow standard (#30853)

* Change base_aws.py to support async_conn
* Add async custom waiter support in get_waiter, and base_waiter.py
* Add Deferrable mode to RedshiftCreateClusterOperator
* Add RedshiftCreateClusterTrigger and unit test
* Add README.md for writing Triggers for AMPP
* Add Deferrable mode to Redshift Pause Cluster Operator
* Add logging to deferrable waiter
* Add check for failure early
syedahsn authored May 23, 2023
1 parent c082aec commit 0b7c095
Showing 5 changed files with 291 additions and 42 deletions.
69 changes: 33 additions & 36 deletions airflow/providers/amazon/aws/operators/redshift_cluster.py
@@ -17,6 +17,7 @@
from __future__ import annotations

import time
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Sequence

from airflow.exceptions import AirflowException
@@ -25,6 +26,7 @@
from airflow.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftClusterTrigger,
RedshiftCreateClusterTrigger,
RedshiftPauseClusterTrigger,
)

if TYPE_CHECKING:
@@ -510,7 +512,9 @@ class RedshiftPauseClusterOperator(BaseOperator):
:param cluster_identifier: id of the AWS Redshift Cluster
:param aws_conn_id: aws connection to use
:param deferrable: Run operator in the deferrable mode. This mode requires an additional aiobotocore>=
:param deferrable: Run operator in the deferrable mode
:param poll_interval: Time (in seconds) to wait between two consecutive calls to check cluster state
:param max_attempts: Maximum number of attempts to poll the cluster
"""

template_fields: Sequence[str] = ("cluster_identifier",)
@@ -524,64 +528,57 @@ def __init__(
aws_conn_id: str = "aws_default",
deferrable: bool = False,
poll_interval: int = 10,
max_attempts: int = 15,
**kwargs,
):
super().__init__(**kwargs)
self.cluster_identifier = cluster_identifier
self.aws_conn_id = aws_conn_id
self.deferrable = deferrable
self.max_attempts = max_attempts
self.poll_interval = poll_interval
# These parameters are added to address an issue with the boto3 API where the API
# These parameters are used to address an issue with the boto3 API where the API
# prematurely reports the cluster as available to receive requests. This causes the cluster
# to reject initial attempts to pause the cluster despite reporting the correct state.
self._attempts = 10
self._remaining_attempts = 10
self._attempt_interval = 15

def execute(self, context: Context):
redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
while self._remaining_attempts >= 1:
try:
redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
break
except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
self._remaining_attempts = self._remaining_attempts - 1

if self._remaining_attempts > 0:
self.log.error(
"Unable to pause cluster. %d attempts remaining.", self._remaining_attempts
)
time.sleep(self._attempt_interval)
else:
raise error
if self.deferrable:
self.defer(
timeout=self.execution_timeout,
trigger=RedshiftClusterTrigger(
task_id=self.task_id,
trigger=RedshiftPauseClusterTrigger(
cluster_identifier=self.cluster_identifier,
poll_interval=self.poll_interval,
max_attempts=self.max_attempts,
aws_conn_id=self.aws_conn_id,
cluster_identifier=self.cluster_identifier,
attempts=self._attempts,
operation_type="pause_cluster",
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
timeout=timedelta(seconds=self.max_attempts * self.poll_interval + 60),
)
else:
while self._attempts >= 1:
try:
redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
return
except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
self._attempts = self._attempts - 1

if self._attempts > 0:
self.log.error("Unable to pause cluster. %d attempts remaining.", self._attempts)
time.sleep(self._attempt_interval)
else:
raise error

def execute_complete(self, context: Context, event: Any = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event:
if "status" in event and event["status"] == "error":
msg = f"{event['status']}: {event['message']}"
raise AirflowException(msg)
elif "status" in event and event["status"] == "success":
self.log.info("%s completed successfully.", self.task_id)
self.log.info("Paused cluster successfully")
def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error pausing cluster: {event}")
else:
raise AirflowException("No event received from trigger")
self.log.info("Paused cluster successfully")
return


class RedshiftDeleteClusterOperator(BaseOperator):
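For orientation (not part of the diff), a minimal DAG sketch of how the deferrable path added above might be used; the DAG id, schedule, and cluster identifier are placeholders:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.redshift_cluster import RedshiftPauseClusterOperator

with DAG(
    dag_id="example_redshift_pause",  # hypothetical DAG id
    start_date=datetime(2023, 5, 1),
    schedule=None,
    catchup=False,
):
    # With deferrable=True the task submits pause_cluster, then defers to
    # RedshiftPauseClusterTrigger, freeing the worker slot while the trigger
    # polls for the "paused" state. The operator sets its deferral timeout to
    # max_attempts * poll_interval plus a 60-second grace period.
    pause_cluster = RedshiftPauseClusterOperator(
        task_id="pause_redshift_cluster",
        cluster_identifier="my-redshift-cluster",  # placeholder
        deferrable=True,
        poll_interval=10,
        max_attempts=15,
    )

With deferrable=False, the operator instead retries pause_cluster synchronously on the worker, as in the else branch shown above.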
73 changes: 73 additions & 0 deletions airflow/providers/amazon/aws/triggers/redshift_cluster.py
@@ -16,8 +16,11 @@
# under the License.
from __future__ import annotations

import asyncio
from typing import Any, AsyncIterator

from botocore.exceptions import WaiterError

from airflow.compat.functools import cached_property
from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftAsyncHook, RedshiftHook
from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -137,3 +140,73 @@ async def run(self):
},
)
yield TriggerEvent({"status": "success", "message": "Cluster Created"})


class RedshiftPauseClusterTrigger(BaseTrigger):
"""
Trigger for RedshiftPauseClusterOperator.
The trigger will asynchronously poll the boto3 API and wait for the
Redshift cluster to be in the `paused` state.
:param cluster_identifier: A unique identifier for the cluster.
:param poll_interval: The amount of time in seconds to wait between attempts.
:param max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials.
"""

def __init__(
self,
cluster_identifier: str,
poll_interval: int,
max_attempts: int,
aws_conn_id: str,
):
self.cluster_identifier = cluster_identifier
self.poll_interval = poll_interval
self.max_attempts = max_attempts
self.aws_conn_id = aws_conn_id

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
"airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftPauseClusterTrigger",
{
"cluster_identifier": self.cluster_identifier,
"poll_interval": str(self.poll_interval),
"max_attempts": str(self.max_attempts),
"aws_conn_id": self.aws_conn_id,
},
)

@cached_property
def hook(self) -> RedshiftHook:
return RedshiftHook(aws_conn_id=self.aws_conn_id)

async def run(self):
async with self.hook.async_conn as client:
attempt = 0
waiter = self.hook.get_waiter("cluster_paused", deferrable=True, client=client)
while attempt < int(self.max_attempts):
attempt = attempt + 1
try:
await waiter.wait(
ClusterIdentifier=self.cluster_identifier,
WaiterConfig={
"Delay": int(self.poll_interval),
"MaxAttempts": 1,
},
)
break
except WaiterError as error:
if "terminal failure" in str(error):
yield TriggerEvent({"status": "failure", "message": f"Pause Cluster Failed: {error}"})
break
self.log.info(
"Status of cluster is %s", error.last_response["Clusters"][0]["ClusterStatus"]
)
await asyncio.sleep(int(self.poll_interval))
if attempt >= int(self.max_attempts):
yield TriggerEvent(
{"status": "failure", "message": "Pause Cluster Failed - max attempts reached."}
)
else:
yield TriggerEvent({"status": "success", "message": "Cluster paused"})
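The run() loop above calls the cluster_paused waiter with MaxAttempts=1 so that the trigger, not botocore, owns the retry budget: each WaiterError is inspected to distinguish a terminal failure from a normal retry, and the current ClusterStatus is logged between polls. As a rough illustration of the trigger contract (an assumption for this page, not part of the commit), the async generator returned by run() can be driven directly; a real AWS connection and cluster identifier would be required:

import asyncio

from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftPauseClusterTrigger


async def first_event():
    trigger = RedshiftPauseClusterTrigger(
        cluster_identifier="my-redshift-cluster",  # placeholder; needs real AWS credentials
        poll_interval=10,
        max_attempts=15,
        aws_conn_id="aws_default",
    )
    # run() is an async generator; pull the first TriggerEvent it yields
    return await trigger.run().asend(None)


print(asyncio.run(first_event()))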
30 changes: 30 additions & 0 deletions airflow/providers/amazon/aws/waiters/redshift.json
@@ -0,0 +1,30 @@
{
"version": 2,
"waiters": {
"cluster_paused": {
"operation": "DescribeClusters",
"delay": 30,
"maxAttempts": 60,
"acceptors": [
{
"matcher": "pathAll",
"argument": "Clusters[].ClusterStatus",
"expected": "paused",
"state": "success"
},
{
"matcher": "error",
"argument": "Clusters[].ClusterStatus",
"expected": "ClusterNotFound",
"state": "retry"
},
{
"matcher": "pathAny",
"argument": "Clusters[].ClusterStatus",
"expected": "deleting",
"state": "failure"
}
]
}
}
}
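This file registers a custom botocore-style waiter named cluster_paused for the Redshift client, which provider hooks resolve by name through get_waiter. A hedged synchronous sketch of that lookup follows (the deferrable path above passes deferrable=True and an async client instead); the cluster identifier is a placeholder and the exact synchronous usage is assumed from the commit description rather than shown in this diff:

from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook

hook = RedshiftHook(aws_conn_id="aws_default")

# Custom waiters declared in waiters/redshift.json are looked up by name,
# alongside the waiters that botocore ships for the Redshift service.
waiter = hook.get_waiter("cluster_paused")
waiter.wait(
    ClusterIdentifier="my-redshift-cluster",  # placeholder
    WaiterConfig={"Delay": 30, "MaxAttempts": 60},
)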
14 changes: 9 additions & 5 deletions tests/providers/amazon/aws/operators/test_redshift_cluster.py
@@ -31,7 +31,10 @@
RedshiftPauseClusterOperator,
RedshiftResumeClusterOperator,
)
from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftClusterTrigger
from airflow.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftClusterTrigger,
RedshiftPauseClusterTrigger,
)


class TestRedshiftCreateClusterOperator:
@@ -389,9 +392,10 @@ def test_pause_cluster_multiple_attempts_fail(self, mock_sleep, mock_conn):
redshift_operator.execute(None)
assert mock_conn.pause_cluster.call_count == 10

def test_pause_cluster_deferrable_mode(self):
@mock.patch.object(RedshiftHook, "get_conn")
def test_pause_cluster_deferrable_mode(self, mock_get_conn):
"""Test Pause cluster operator with defer when deferrable param is true"""

mock_get_conn().pause_cluster.return_value = True
redshift_operator = RedshiftPauseClusterOperator(
task_id="task_test", cluster_identifier="test_cluster", deferrable=True
)
@@ -400,8 +404,8 @@ def test_pause_cluster_execute_complete_success(self):
redshift_operator.execute(context=None)

assert isinstance(
exc.value.trigger, RedshiftClusterTrigger
), "Trigger is not a RedshiftClusterTrigger"
exc.value.trigger, RedshiftPauseClusterTrigger
), "Trigger is not a RedshiftPauseClusterTrigger"

def test_pause_cluster_execute_complete_success(self):
"""Asserts that logging occurs as expected"""
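To complement the truncated test above, a hedged sketch (not part of the commit) of how the new execute_complete contract, which raises on any non-success event, might be exercised:

import pytest

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.redshift_cluster import RedshiftPauseClusterOperator


def test_pause_cluster_execute_complete_failure_event():
    operator = RedshiftPauseClusterOperator(
        task_id="task_test", cluster_identifier="test_cluster", deferrable=True
    )
    # Any event whose status is not "success" should surface as an AirflowException
    with pytest.raises(AirflowException, match="Error pausing cluster"):
        operator.execute_complete(context=None, event={"status": "failure", "message": "boom"})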