diff --git a/airflow/providers/amazon/aws/operators/sns.py b/airflow/providers/amazon/aws/operators/sns.py index 6b16dc074156c..18e3255c6fe78 100644 --- a/airflow/providers/amazon/aws/operators/sns.py +++ b/airflow/providers/amazon/aws/operators/sns.py @@ -20,14 +20,15 @@ from typing import TYPE_CHECKING, Sequence -from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.sns import SnsHook +from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields if TYPE_CHECKING: from airflow.utils.context import Context -class SnsPublishOperator(BaseOperator): +class SnsPublishOperator(AwsBaseOperator[SnsHook]): """ Publish a message to Amazon SNS. @@ -35,16 +36,30 @@ class SnsPublishOperator(BaseOperator): For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:SnsPublishOperator` - :param aws_conn_id: aws connection to use :param target_arn: either a TopicArn or an EndpointArn :param message: the default message you want to send (templated) :param subject: the message subject you want to send (templated) :param message_attributes: the message attributes you want to send as a flat dict (data type will be determined automatically) + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields: Sequence[str] = ("target_arn", "message", "subject", "message_attributes", "aws_conn_id") - template_ext: Sequence[str] = () + aws_hook_class = SnsHook + template_fields: Sequence[str] = aws_template_fields( + "target_arn", + "message", + "subject", + "message_attributes", + ) template_fields_renderers = {"message_attributes": "json"} def __init__( @@ -54,7 +69,6 @@ def __init__( message: str, subject: str | None = None, message_attributes: dict | None = None, - aws_conn_id: str = "aws_default", **kwargs, ): super().__init__(**kwargs) @@ -62,11 +76,8 @@ def __init__( self.message = message self.subject = subject self.message_attributes = message_attributes - self.aws_conn_id = aws_conn_id def execute(self, context: Context): - sns = SnsHook(aws_conn_id=self.aws_conn_id) - self.log.info( "Sending SNS notification to %s using %s:\nsubject=%s\nattributes=%s\nmessage=%s", self.target_arn, @@ -76,7 +87,7 @@ def execute(self, context: Context): self.message, ) - return sns.publish_to_target( + return self.hook.publish_to_target( target_arn=self.target_arn, message=self.message, subject=self.subject, diff --git a/docs/apache-airflow-providers-amazon/operators/sns.rst b/docs/apache-airflow-providers-amazon/operators/sns.rst index 903b8d7630b4a..e589e38f89b07 100644 --- a/docs/apache-airflow-providers-amazon/operators/sns.rst +++ b/docs/apache-airflow-providers-amazon/operators/sns.rst @@ -32,6 +32,11 @@ Prerequisite Tasks .. include:: ../_partials/prerequisite_tasks.rst +Generic Parameters +------------------ + +.. include:: ../_partials/generic_parameters.rst + Operators --------- diff --git a/tests/providers/amazon/aws/operators/test_sns.py b/tests/providers/amazon/aws/operators/test_sns.py index 89f0d6d26636d..780bc7eade3ea 100644 --- a/tests/providers/amazon/aws/operators/test_sns.py +++ b/tests/providers/amazon/aws/operators/test_sns.py @@ -19,6 +19,8 @@ from unittest import mock +import pytest + from airflow.providers.amazon.aws.operators.sns import SnsPublishOperator TASK_ID = "sns_publish_job" @@ -30,44 +32,47 @@ class TestSnsPublishOperator: + @pytest.fixture(autouse=True) + def setup_test_cases(self): + self.default_op_kwargs = { + "task_id": TASK_ID, + "target_arn": TARGET_ARN, + "message": MESSAGE, + "subject": SUBJECT, + "message_attributes": MESSAGE_ATTRIBUTES, + } + def test_init(self): - # Given / When - operator = SnsPublishOperator( - task_id=TASK_ID, + op = SnsPublishOperator(**self.default_op_kwargs) + assert op.hook.aws_conn_id == "aws_default" + assert op.hook._region_name is None + assert op.hook._verify is None + assert op.hook._config is None + + op = SnsPublishOperator( + **self.default_op_kwargs, aws_conn_id=AWS_CONN_ID, - target_arn=TARGET_ARN, - message=MESSAGE, - subject=SUBJECT, - message_attributes=MESSAGE_ATTRIBUTES, + region_name="us-west-1", + verify="/spam/egg.pem", + botocore_config={"read_timeout": 42}, ) + assert op.hook.aws_conn_id == AWS_CONN_ID + assert op.hook._region_name == "us-west-1" + assert op.hook._verify == "/spam/egg.pem" + assert op.hook._config is not None + assert op.hook._config.read_timeout == 42 - # Then - assert TASK_ID == operator.task_id - assert AWS_CONN_ID == operator.aws_conn_id - assert TARGET_ARN == operator.target_arn - assert MESSAGE == operator.message - assert SUBJECT == operator.subject - assert MESSAGE_ATTRIBUTES == operator.message_attributes - - @mock.patch("airflow.providers.amazon.aws.operators.sns.SnsHook") - def test_execute(self, mock_hook): - # Given + @mock.patch.object(SnsPublishOperator, "hook") + def test_execute(self, mocked_hook): hook_response = {"MessageId": "foobar"} + mocked_hook.publish_to_target.return_value = hook_response - hook_instance = mock_hook.return_value - hook_instance.publish_to_target.return_value = hook_response + op = SnsPublishOperator(**self.default_op_kwargs) + assert op.execute({}) == hook_response - operator = SnsPublishOperator( - task_id=TASK_ID, - aws_conn_id=AWS_CONN_ID, - target_arn=TARGET_ARN, + mocked_hook.publish_to_target.assert_called_once_with( message=MESSAGE, - subject=SUBJECT, message_attributes=MESSAGE_ATTRIBUTES, + subject=SUBJECT, + target_arn=TARGET_ARN, ) - - # When - result = operator.execute(None) - - # Then - assert hook_response == result