diff --git a/src/cfnlint/rules/resources/cloudfront/DistributionTargetOriginId.py b/src/cfnlint/rules/resources/cloudfront/DistributionTargetOriginId.py new file mode 100644 index 0000000000..39f0d963c5 --- /dev/null +++ b/src/cfnlint/rules/resources/cloudfront/DistributionTargetOriginId.py @@ -0,0 +1,58 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +SPDX-License-Identifier: MIT-0 +""" + +from __future__ import annotations + +from collections import deque +from typing import Any + +from cfnlint.jsonschema import ValidationError, ValidationResult, Validator +from cfnlint.rules.helpers import get_value_from_path +from cfnlint.rules.jsonschema.CfnLintKeyword import CfnLintKeyword + + +class DistributionTargetOriginId(CfnLintKeyword): + id = "E3057" + shortdesc = "Validate that CloudFront TargetOriginId is a specified Origin" + description = ( + "CloudFront TargetOriginId has to map to an Origin Id that " + "is in the same DistributionConfig" + ) + source_url = "https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-properties-cloudfront-distribution-defaultcachebehavior.html#cfn-cloudfront-distribution-defaultcachebehavior-targetoriginid" + tags = ["properties", "cloudfront"] + + def __init__(self): + """Init""" + super().__init__( + keywords=[ + "Resources/AWS::CloudFront::Distribution/Properties/DistributionConfig" + ] + ) + + def validate( + self, validator: Validator, _, instance: Any, schema: dict[str, Any] + ) -> ValidationResult: + + for cache_origin_id, cache_validator in get_value_from_path( + validator, instance, path=deque(["DefaultCacheBehavior", "TargetOriginId"]) + ): + if not validator.is_type(cache_origin_id, "string"): + continue + origin_ids = [] + for origin_id, _ in get_value_from_path( + cache_validator, instance, path=deque(["Origins", "*", "Id"]) + ): + if not validator.is_type(origin_id, "string"): + break + if origin_id == cache_origin_id: + break + origin_ids.append(origin_id) + else: + yield ValidationError( + message=f"{cache_origin_id!r} is not one of {origin_ids!r}", + rule=self, + path_override=cache_validator.context.path.path, + validator="enum", + ) diff --git a/test/unit/rules/resources/cloudfront/test_distribution_target_origin_id.py b/test/unit/rules/resources/cloudfront/test_distribution_target_origin_id.py new file mode 100644 index 0000000000..563364c880 --- /dev/null +++ b/test/unit/rules/resources/cloudfront/test_distribution_target_origin_id.py @@ -0,0 +1,131 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +SPDX-License-Identifier: MIT-0 +""" + +from collections import deque + +import pytest + +from cfnlint.jsonschema import ValidationError +from cfnlint.rules.resources.cloudfront.DistributionTargetOriginId import ( + DistributionTargetOriginId, +) + + +@pytest.fixture(scope="module") +def rule(): + rule = DistributionTargetOriginId() + yield rule + + +@pytest.mark.parametrize( + "instance,expected", + [ + ( + [], # wrong type should return no issues + [], + ), + ( + { + "DefaultCacheBehavior": { + "TargetOriginId": "origin-id", + }, + "Origins": [ + { + "Id": "origin-id", + }, + ], + }, + [], + ), + ( + { + "DefaultCacheBehavior": { + "TargetOriginId": "origin-id", + }, + "Origins": [ + { + "Id": "foo", + }, + { + "Id": "origin-id", + }, + { + "Id": "bar", + }, + ], + }, + [], + ), + ( + { + "DefaultCacheBehavior": { + "TargetOriginId": "origin-id", + }, + "Origins": [ + { + "Id": "foo", + }, + { + "Id": "bar", + }, + ], + }, + [ + ValidationError( + ("'origin-id' is not one of " "['foo', 'bar']"), + rule=DistributionTargetOriginId(), + path=deque([]), + validator="enum", + path_override=deque(["DefaultCacheBehavior", "TargetOriginId"]), + ) + ], + ), + ( + { + "DefaultCacheBehavior": { + "TargetOriginId": {"Ref": "MyParameter"}, + }, + "Origins": [ + { + "Id": "foo", + }, + { + "Id": "bar", + }, + ], + }, + [], + ), + ( + { + "DefaultCacheBehavior": { + "TargetOriginId": "origin-id", + }, + "Origins": [ + { + "Id": "foo", + }, + { + "Id": {"Ref": "MyParameter"}, + }, + { + "Id": "bar", + }, + ], + }, + [], + ), + ], +) +def test_validate(instance, expected, rule, validator): + errs = list(rule.validate(validator, "", instance, {})) + for err in errs: + print(err.path) + print(err.path_override) + print(err.validator) + print(err.schema_path) + print(err.message) + print(err.rule) + assert errs == expected, f"Expected {expected} got {errs}"