Skip to content

Commit

Permalink
Use base aws classes in Amazon SQS Operators/Sensors/Triggers (#36613)
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis authored Jan 5, 2024
1 parent 034e618 commit 16d16e2
Show file tree
Hide file tree
Showing 7 changed files with 329 additions and 298 deletions.
25 changes: 16 additions & 9 deletions airflow/providers/amazon/aws/operators/sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@

from typing import TYPE_CHECKING, Sequence

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.sqs import SqsHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context


class SqsPublishOperator(BaseOperator):
class SqsPublishOperator(AwsBaseOperator[SqsHook]):
"""
Publish a message to an Amazon SQS queue.
Expand All @@ -41,10 +42,20 @@ class SqsPublishOperator(BaseOperator):
:param delay_seconds: message delay in seconds (templated) (default: 0)
:param message_group_id: This parameter applies only to FIFO (first-in-first-out) queues. (default: None)
For details of the attributes parameter see :py:meth:`botocore.client.SQS.send_message`
:param aws_conn_id: AWS connection id (default: aws_default)
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = (
aws_hook_class = SqsHook
template_fields: Sequence[str] = aws_template_fields(
"sqs_queue",
"message_content",
"delay_seconds",
Expand All @@ -62,12 +73,10 @@ def __init__(
message_attributes: dict | None = None,
delay_seconds: int = 0,
message_group_id: str | None = None,
aws_conn_id: str = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
self.sqs_queue = sqs_queue
self.aws_conn_id = aws_conn_id
self.message_content = message_content
self.delay_seconds = delay_seconds
self.message_attributes = message_attributes or {}
Expand All @@ -81,9 +90,7 @@ def execute(self, context: Context) -> dict:
:return: dict with information about the message sent
For details of the returned dict see :py:meth:`botocore.client.SQS.send_message`
"""
hook = SqsHook(aws_conn_id=self.aws_conn_id)

result = hook.send_message(
result = self.hook.send_message(
queue_url=self.sqs_queue,
message_body=self.message_content,
delay_seconds=self.delay_seconds,
Expand Down
34 changes: 21 additions & 13 deletions airflow/providers/amazon/aws/sensors/sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,25 @@
"""Reads and then deletes the message from SQS queue."""
from __future__ import annotations

from functools import cached_property
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Collection, Sequence

from deprecated import deprecated

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.providers.amazon.aws.hooks.sqs import SqsHook
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.triggers.sqs import SqsSensorTrigger
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
from airflow.providers.amazon.aws.utils.sqs import MessageFilteringType, process_response
from airflow.sensors.base import BaseSensorOperator

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
from airflow.utils.context import Context
from datetime import timedelta


class SqsSensor(BaseSensorOperator):
class SqsSensor(AwsBaseSensor[SqsHook]):
"""
Get messages from an Amazon SQS queue and then delete the messages from the queue.
Expand All @@ -51,7 +51,6 @@ class SqsSensor(BaseSensorOperator):
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/sensor:SqsSensor`
:param aws_conn_id: AWS connection id
:param sqs_queue: The SQS queue url (templated)
:param max_messages: The maximum number of messages to retrieve for each poke (templated)
:param num_batches: The number of times the sensor will call the SQS API to receive messages (default: 1)
Expand All @@ -75,16 +74,27 @@ class SqsSensor(BaseSensorOperator):
:param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
module to be installed.
(default: False, but can be overridden in config file by setting default_deferrable to True)
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = ("sqs_queue", "max_messages", "message_filtering_config")
aws_hook_class = SqsHook
template_fields: Sequence[str] = aws_template_fields(
"sqs_queue", "max_messages", "message_filtering_config"
)

def __init__(
self,
*,
sqs_queue,
aws_conn_id: str = "aws_default",
max_messages: int = 5,
num_batches: int = 1,
wait_time_seconds: int = 1,
Expand All @@ -98,7 +108,6 @@ def __init__(
):
super().__init__(**kwargs)
self.sqs_queue = sqs_queue
self.aws_conn_id = aws_conn_id
self.max_messages = max_messages
self.num_batches = num_batches
self.wait_time_seconds = wait_time_seconds
Expand Down Expand Up @@ -135,6 +144,9 @@ def execute(self, context: Context) -> Any:
message_filtering_config=self.message_filtering_config,
delete_message_on_reception=self.delete_message_on_reception,
waiter_delay=int(self.poke_interval),
region_name=self.region_name,
verify=self.verify,
botocore_config=self.botocore_config,
),
method_name="execute_complete",
timeout=timedelta(seconds=self.timeout),
Expand Down Expand Up @@ -220,7 +232,3 @@ def poke(self, context: Context):
def get_hook(self) -> SqsHook:
"""Create and return an SqsHook."""
return self.hook

@cached_property
def hook(self) -> SqsHook:
return SqsHook(aws_conn_id=self.aws_conn_id)
21 changes: 18 additions & 3 deletions airflow/providers/amazon/aws/triggers/sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class SqsSensorTrigger(BaseTrigger):
def __init__(
self,
sqs_queue: str,
aws_conn_id: str = "aws_default",
aws_conn_id: str | None = "aws_default",
max_messages: int = 5,
num_batches: int = 1,
wait_time_seconds: int = 1,
Expand All @@ -69,9 +69,11 @@ def __init__(
message_filtering_config: Any = None,
delete_message_on_reception: bool = True,
waiter_delay: int = 60,
region_name: str | None = None,
verify: bool | str | None = None,
botocore_config: dict | None = None,
):
self.sqs_queue = sqs_queue
self.aws_conn_id = aws_conn_id
self.max_messages = max_messages
self.num_batches = num_batches
self.wait_time_seconds = wait_time_seconds
Expand All @@ -82,6 +84,11 @@ def __init__(
self.message_filtering_config = message_filtering_config
self.waiter_delay = waiter_delay

self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.verify = verify
self.botocore_config = botocore_config

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
self.__class__.__module__ + "." + self.__class__.__qualname__,
Expand All @@ -97,12 +104,20 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"message_filtering_match_values": self.message_filtering_match_values,
"message_filtering_config": self.message_filtering_config,
"waiter_delay": self.waiter_delay,
"region_name": self.region_name,
"verify": self.verify,
"botocore_config": self.botocore_config,
},
)

@property
def hook(self) -> SqsHook:
return SqsHook(aws_conn_id=self.aws_conn_id)
return SqsHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)

async def poll_sqs(self, client: BaseAwsConnection) -> Collection:
"""
Expand Down
5 changes: 5 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/sqs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ Prerequisite Tasks

.. include:: ../_partials/prerequisite_tasks.rst

Generic Parameters
------------------

.. include:: ../_partials/generic_parameters.rst

Operators
---------

Expand Down
111 changes: 65 additions & 46 deletions tests/providers/amazon/aws/operators/test_sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,85 +17,104 @@
# under the License.
from __future__ import annotations

from unittest.mock import MagicMock
from unittest import mock

import pytest
from botocore.exceptions import ClientError
from moto import mock_sqs

from airflow.models.dag import DAG
from airflow.providers.amazon.aws.hooks.sqs import SqsHook
from airflow.providers.amazon.aws.operators.sqs import SqsPublishOperator
from airflow.utils import timezone

DEFAULT_DATE = timezone.datetime(2019, 1, 1)

REGION_NAME = "eu-west-1"
QUEUE_NAME = "test-queue"
QUEUE_URL = f"https://{QUEUE_NAME}"

FIFO_QUEUE_NAME = "test-queue.fifo"
FIFO_QUEUE_URL = f"https://{FIFO_QUEUE_NAME}"


@pytest.fixture
def mocked_context():
return mock.MagicMock(name="FakeContext")


class TestSqsPublishOperator:
def setup_method(self):
args = {"owner": "airflow", "start_date": DEFAULT_DATE}

self.dag = DAG("test_dag_id", default_args=args)
self.operator = SqsPublishOperator(
task_id="test_task",
dag=self.dag,
sqs_queue=QUEUE_URL,
message_content="hello",
aws_conn_id="aws_default",
@pytest.fixture(autouse=True)
def setup_test_cases(self):
self.default_op_kwargs = {
"task_id": "test_task",
"message_content": "hello",
"aws_conn_id": None,
"region_name": REGION_NAME,
}
self.sqs_client = SqsHook(aws_conn_id=None, region_name=REGION_NAME).conn

def test_init(self):
self.default_op_kwargs.pop("aws_conn_id", None)
self.default_op_kwargs.pop("region_name", None)

op = SqsPublishOperator(sqs_queue=QUEUE_NAME, **self.default_op_kwargs)
assert op.hook.aws_conn_id == "aws_default"
assert op.hook._region_name is None
assert op.hook._verify is None
assert op.hook._config is None

op = SqsPublishOperator(
sqs_queue=FIFO_QUEUE_NAME,
**self.default_op_kwargs,
aws_conn_id=None,
region_name=REGION_NAME,
verify="/spam/egg.pem",
botocore_config={"read_timeout": 42},
)

self.mock_context = MagicMock()
self.sqs_hook = SqsHook()
assert op.hook.aws_conn_id is None
assert op.hook._region_name == REGION_NAME
assert op.hook._verify == "/spam/egg.pem"
assert op.hook._config is not None
assert op.hook._config.read_timeout == 42

@mock_sqs
def test_execute_success(self):
self.sqs_hook.create_queue(QUEUE_NAME)
def test_execute_success(self, mocked_context):
self.sqs_client.create_queue(QueueName=QUEUE_NAME)

result = self.operator.execute(self.mock_context)
# Send SQS Message
op = SqsPublishOperator(**self.default_op_kwargs, sqs_queue=QUEUE_NAME)
result = op.execute(mocked_context)
assert "MD5OfMessageBody" in result
assert "MessageId" in result

message = self.sqs_hook.get_conn().receive_message(QueueUrl=QUEUE_URL)

# Validate message through moto
message = self.sqs_client.receive_message(QueueUrl=QUEUE_URL)
assert len(message["Messages"]) == 1
assert message["Messages"][0]["MessageId"] == result["MessageId"]
assert message["Messages"][0]["Body"] == "hello"

context_calls = []

assert self.mock_context["ti"].method_calls == context_calls, "context call should be same"

@mock_sqs
def test_execute_failure_fifo_queue(self):
self.operator.sqs_queue = FIFO_QUEUE_URL
self.sqs_hook.create_queue(FIFO_QUEUE_NAME, attributes={"FifoQueue": "true"})
with pytest.raises(ClientError) as ctx:
self.operator.execute(self.mock_context)
err_msg = (
"An error occurred (MissingParameter) when calling the SendMessage operation: The request must "
"contain the parameter MessageGroupId."
def test_execute_failure_fifo_queue(self, mocked_context):
self.sqs_client.create_queue(QueueName=FIFO_QUEUE_NAME, Attributes={"FifoQueue": "true"})

op = SqsPublishOperator(**self.default_op_kwargs, sqs_queue=FIFO_QUEUE_NAME)
error_message = (
"An error occurred \(MissingParameter\) when calling the SendMessage operation: "
"The request must contain the parameter MessageGroupId."
)
assert err_msg == str(ctx.value)
with pytest.raises(ClientError, match=error_message):
op.execute(mocked_context)

@mock_sqs
def test_execute_success_fifo_queue(self):
self.operator.sqs_queue = FIFO_QUEUE_URL
self.operator.message_group_id = "abc"
self.sqs_hook.create_queue(
FIFO_QUEUE_NAME, attributes={"FifoQueue": "true", "ContentBasedDeduplication": "true"}
def test_execute_success_fifo_queue(self, mocked_context):
self.sqs_client.create_queue(
QueueName=FIFO_QUEUE_NAME, Attributes={"FifoQueue": "true", "ContentBasedDeduplication": "true"}
)
result = self.operator.execute(self.mock_context)

# Send SQS Message into the FIFO Queue
op = SqsPublishOperator(**self.default_op_kwargs, sqs_queue=FIFO_QUEUE_NAME, message_group_id="abc")
result = op.execute(mocked_context)
assert "MD5OfMessageBody" in result
assert "MessageId" in result
message = self.sqs_hook.get_conn().receive_message(
QueueUrl=FIFO_QUEUE_URL, AttributeNames=["MessageGroupId"]
)

# Validate message through moto
message = self.sqs_client.receive_message(QueueUrl=FIFO_QUEUE_URL, AttributeNames=["MessageGroupId"])
assert len(message["Messages"]) == 1
assert message["Messages"][0]["MessageId"] == result["MessageId"]
assert message["Messages"][0]["Body"] == "hello"
Expand Down
Loading

0 comments on commit 16d16e2

Please sign in to comment.