Skip to content

Commit

Permalink
Use base aws classes in Amazon SNS Operators (#36615)
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis authored Jan 5, 2024
1 parent 1cc9fe1 commit 034e618
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 41 deletions.
31 changes: 21 additions & 10 deletions airflow/providers/amazon/aws/operators/sns.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,46 @@

from typing import TYPE_CHECKING, Sequence

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.sns import SnsHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context


class SnsPublishOperator(BaseOperator):
class SnsPublishOperator(AwsBaseOperator[SnsHook]):
"""
Publish a message to Amazon SNS.
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:SnsPublishOperator`
:param aws_conn_id: aws connection to use
:param target_arn: either a TopicArn or an EndpointArn
:param message: the default message you want to send (templated)
:param subject: the message subject you want to send (templated)
:param message_attributes: the message attributes you want to send as a flat dict (data type will be
determined automatically)
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = ("target_arn", "message", "subject", "message_attributes", "aws_conn_id")
template_ext: Sequence[str] = ()
aws_hook_class = SnsHook
template_fields: Sequence[str] = aws_template_fields(
"target_arn",
"message",
"subject",
"message_attributes",
)
template_fields_renderers = {"message_attributes": "json"}

def __init__(
Expand All @@ -54,19 +69,15 @@ def __init__(
message: str,
subject: str | None = None,
message_attributes: dict | None = None,
aws_conn_id: str = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
self.target_arn = target_arn
self.message = message
self.subject = subject
self.message_attributes = message_attributes
self.aws_conn_id = aws_conn_id

def execute(self, context: Context):
sns = SnsHook(aws_conn_id=self.aws_conn_id)

self.log.info(
"Sending SNS notification to %s using %s:\nsubject=%s\nattributes=%s\nmessage=%s",
self.target_arn,
Expand All @@ -76,7 +87,7 @@ def execute(self, context: Context):
self.message,
)

return sns.publish_to_target(
return self.hook.publish_to_target(
target_arn=self.target_arn,
message=self.message,
subject=self.subject,
Expand Down
5 changes: 5 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/sns.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ Prerequisite Tasks

.. include:: ../_partials/prerequisite_tasks.rst

Generic Parameters
------------------

.. include:: ../_partials/generic_parameters.rst

Operators
---------

Expand Down
67 changes: 36 additions & 31 deletions tests/providers/amazon/aws/operators/test_sns.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

from unittest import mock

import pytest

from airflow.providers.amazon.aws.operators.sns import SnsPublishOperator

TASK_ID = "sns_publish_job"
Expand All @@ -30,44 +32,47 @@


class TestSnsPublishOperator:
@pytest.fixture(autouse=True)
def setup_test_cases(self):
self.default_op_kwargs = {
"task_id": TASK_ID,
"target_arn": TARGET_ARN,
"message": MESSAGE,
"subject": SUBJECT,
"message_attributes": MESSAGE_ATTRIBUTES,
}

def test_init(self):
# Given / When
operator = SnsPublishOperator(
task_id=TASK_ID,
op = SnsPublishOperator(**self.default_op_kwargs)
assert op.hook.aws_conn_id == "aws_default"
assert op.hook._region_name is None
assert op.hook._verify is None
assert op.hook._config is None

op = SnsPublishOperator(
**self.default_op_kwargs,
aws_conn_id=AWS_CONN_ID,
target_arn=TARGET_ARN,
message=MESSAGE,
subject=SUBJECT,
message_attributes=MESSAGE_ATTRIBUTES,
region_name="us-west-1",
verify="/spam/egg.pem",
botocore_config={"read_timeout": 42},
)
assert op.hook.aws_conn_id == AWS_CONN_ID
assert op.hook._region_name == "us-west-1"
assert op.hook._verify == "/spam/egg.pem"
assert op.hook._config is not None
assert op.hook._config.read_timeout == 42

# Then
assert TASK_ID == operator.task_id
assert AWS_CONN_ID == operator.aws_conn_id
assert TARGET_ARN == operator.target_arn
assert MESSAGE == operator.message
assert SUBJECT == operator.subject
assert MESSAGE_ATTRIBUTES == operator.message_attributes

@mock.patch("airflow.providers.amazon.aws.operators.sns.SnsHook")
def test_execute(self, mock_hook):
# Given
@mock.patch.object(SnsPublishOperator, "hook")
def test_execute(self, mocked_hook):
hook_response = {"MessageId": "foobar"}
mocked_hook.publish_to_target.return_value = hook_response

hook_instance = mock_hook.return_value
hook_instance.publish_to_target.return_value = hook_response
op = SnsPublishOperator(**self.default_op_kwargs)
assert op.execute({}) == hook_response

operator = SnsPublishOperator(
task_id=TASK_ID,
aws_conn_id=AWS_CONN_ID,
target_arn=TARGET_ARN,
mocked_hook.publish_to_target.assert_called_once_with(
message=MESSAGE,
subject=SUBJECT,
message_attributes=MESSAGE_ATTRIBUTES,
subject=SUBJECT,
target_arn=TARGET_ARN,
)

# When
result = operator.execute(None)

# Then
assert hook_response == result

0 comments on commit 034e618

Please sign in to comment.