diff --git a/aws_lambda_powertools/utilities/parser/models/sns.py b/aws_lambda_powertools/utilities/parser/models/sns.py index e329162e5c8..1b095fde2c4 100644 --- a/aws_lambda_powertools/utilities/parser/models/sns.py +++ b/aws_lambda_powertools/utilities/parser/models/sns.py @@ -31,8 +31,11 @@ class SnsNotificationModel(BaseModel): def check_sqs_protocol(cls, values): sqs_rewritten_keys = ("UnsubscribeURL", "SigningCertURL") if any(key in sqs_rewritten_keys for key in values): - values["UnsubscribeUrl"] = values.pop("UnsubscribeURL") - values["SigningCertUrl"] = values.pop("SigningCertURL") + # The sentinel value 'None' forces the validator to fail with + # ValidatorError instead of KeyError when the key is missing from + # the SQS payload + values["UnsubscribeUrl"] = values.pop("UnsubscribeURL", None) + values["SigningCertUrl"] = values.pop("SigningCertURL", None) return values diff --git a/tests/functional/parser/test_sns.py b/tests/functional/parser/test_sns.py index 81158a4419e..b0d9ff69a9b 100644 --- a/tests/functional/parser/test_sns.py +++ b/tests/functional/parser/test_sns.py @@ -1,3 +1,4 @@ +import json from typing import Any, List import pytest @@ -103,3 +104,29 @@ def handle_sns_sqs_json_body(event: List[MySnsBusiness], _: LambdaContext): def test_handle_sns_sqs_trigger_event_json_body(): # noqa: F811 event_dict = load_event("snsSqsEvent.json") handle_sns_sqs_json_body(event_dict, LambdaContext()) + + +def test_handle_sns_sqs_trigger_event_json_body_missing_signing_cert_url(): + # GIVEN an event is tampered with a missing SigningCertURL + event_dict = load_event("snsSqsEvent.json") + payload = json.loads(event_dict["Records"][0]["body"]) + payload.pop("SigningCertURL") + event_dict["Records"][0]["body"] = json.dumps(payload) + + # WHEN parsing the payload + # THEN raise a ValidationError error + with pytest.raises(ValidationError): + handle_sns_sqs_json_body(event_dict, LambdaContext()) + + +def test_handle_sns_sqs_trigger_event_json_body_missing_unsubscribe_url(): + # GIVEN an event is tampered with a missing UnsubscribeURL + event_dict = load_event("snsSqsEvent.json") + payload = json.loads(event_dict["Records"][0]["body"]) + payload.pop("UnsubscribeURL") + event_dict["Records"][0]["body"] = json.dumps(payload) + + # WHEN parsing the payload + # THEN raise a ValidationError error + with pytest.raises(ValidationError): + handle_sns_sqs_json_body(event_dict, LambdaContext())