From 0449fe3e2cd676fff96d0f11341bae1831208428 Mon Sep 17 00:00:00 2001
From: Kevin DeJong <kddejong@amazon.com>
Date: Tue, 25 Jun 2024 17:41:22 -0700
Subject: [PATCH] Greatly simplify findinmap resolution (#3406)

---
 src/cfnlint/jsonschema/_resolvers_cfn.py      | 133 ++++++------------
 .../module/jsonschema/test_resolvers_cfn.py   |  65 +++++++++
 test/unit/rules/functions/test_find_in_map.py |  17 ++-
 3 files changed, 125 insertions(+), 90 deletions(-)

diff --git a/src/cfnlint/jsonschema/_resolvers_cfn.py b/src/cfnlint/jsonschema/_resolvers_cfn.py
index d5c1eb3028..b2a86f3311 100644
--- a/src/cfnlint/jsonschema/_resolvers_cfn.py
+++ b/src/cfnlint/jsonschema/_resolvers_cfn.py
@@ -14,7 +14,6 @@
 from cfnlint.helpers import AVAILABILITY_ZONES, REGEX_SUB_PARAMETERS
 from cfnlint.jsonschema import ValidationError, Validator
 from cfnlint.jsonschema._typing import ResolutionResult
-from cfnlint.jsonschema._utils import equal
 
 
 def unresolvable(validator: Validator, instance: Any) -> ResolutionResult:
@@ -55,94 +54,50 @@ def find_in_map(validator: Validator, instance: Any) -> ResolutionResult:
                     ), None
                     default_value = value
 
-    for map_name, map_v, _ in validator.resolve_value(instance[0]):
-        if not validator.is_type(map_name, "string"):
-            continue
-        for top_level_key, top_v, _ in validator.resolve_value(instance[1]):
-            if validator.is_type(top_level_key, "integer"):
-                top_level_key = str(top_level_key)
-            if not validator.is_type(top_level_key, "string"):
-                continue
-            for second_level_key, second_v, err in validator.resolve_value(instance[2]):
-                if validator.is_type(second_level_key, "integer"):
-                    second_level_key = str(second_level_key)
-                if not validator.is_type(second_level_key, "string"):
-                    continue
-                try:
-                    mappings = list(validator.context.mappings.keys())
-                    if not default_value and all(
-                        not (equal(map_name, each)) for each in mappings
-                    ):
-                        yield None, map_v.evolve(
-                            context=map_v.context.evolve(
-                                path=map_v.context.path.evolve(value_path=deque([0])),
-                            ),
-                        ), ValidationError(
-                            f"{map_name!r} is not one of {mappings!r}", path=[0]
-                        )
-                        continue
-
-                    top_level_keys = list(
-                        validator.context.mappings[map_name].keys.keys()
-                    )
-                    if not default_value and all(
-                        not (equal(top_level_key, each)) for each in top_level_keys
-                    ):
-                        yield None, top_v.evolve(
-                            context=top_v.context.evolve(
-                                path=top_v.context.path.evolve(value_path=deque([1])),
-                            ),
-                        ), ValidationError(
-                            f"{top_level_key!r} is not one of {top_level_keys!r}",
-                            path=[1],
-                        )
-                        continue
-
-                    second_level_keys = list(
-                        validator.context.mappings[map_name]
-                        .keys[top_level_key]
-                        .keys.keys()
-                    )
-                    if not default_value and all(
-                        not (equal(second_level_key, each))
-                        for each in second_level_keys
-                    ):
-                        yield None, second_v.evolve(
-                            context=second_v.context.evolve(
-                                path=second_v.context.path.evolve(
-                                    value_path=deque([2])
-                                ),
-                            ),
-                        ), ValidationError(
-                            f"{second_level_key!r} is not one of {second_level_keys!r}",
-                            path=[2],
-                        )
-                        continue
-
-                    for value in validator.context.mappings[map_name].find_in_map(
-                        top_level_key,
-                        second_level_key,
-                    ):
-                        yield (
-                            value,
-                            validator.evolve(
-                                context=validator.context.evolve(
-                                    path=validator.context.path.evolve(
-                                        value_path=deque(
-                                            [
-                                                "Mappings",
-                                                map_name,
-                                                top_level_key,
-                                                second_level_key,
-                                            ]
-                                        )
-                                    )
-                                )
-                            ),
-                            None,
-                        )
-                except KeyError:
-                    pass
+    if (
+        validator.is_type(instance[0], "string")
+        and (
+            validator.is_type(instance[1], "string")
+            or validator.is_type(instance[1], "integer")
+        )
+        and validator.is_type(instance[2], "string")
+    ):
+        map = validator.context.mappings.get(instance[0])
+        if not map:
+            if not default_value:
+                yield None, validator, ValidationError(
+                    (
+                        f"{instance[0]!r} is not one of "
+                        f"{list(validator.context.mappings.keys())!r}"
+                    ),
+                    path=deque([0]),
+                )
+            return
+
+        top_key = map.keys.get(instance[1])
+        if not top_key:
+            if not default_value:
+                yield None, validator, ValidationError(
+                    (
+                        f"{instance[1]!r} is not one of "
+                        f"{list(map.keys.keys())!r} for "
+                        f"mapping {instance[0]!r}"
+                    ),
+                    path=deque([1]),
+                )
+            return
+
+        value = top_key.keys.get(instance[2])
+        if not value:
+            if not default_value:
+                yield default_value, validator, ValidationError(
+                    (
+                        f"{instance[2]!r} is not one of "
+                        f"{list(top_key.keys.keys())!r} for mapping "
+                        f"{instance[0]!r} and key {instance[1]!r}"
+                    ),
+                    path=deque([2]),
+                )
 
 
 def get_azs(validator: Validator, instance: Any) -> ResolutionResult:
diff --git a/test/unit/module/jsonschema/test_resolvers_cfn.py b/test/unit/module/jsonschema/test_resolvers_cfn.py
index 88757f7e8a..b9f097bd7b 100644
--- a/test/unit/module/jsonschema/test_resolvers_cfn.py
+++ b/test/unit/module/jsonschema/test_resolvers_cfn.py
@@ -8,6 +8,7 @@
 import pytest
 
 from cfnlint.context.context import Context, Map
+from cfnlint.jsonschema import ValidationError
 from cfnlint.jsonschema.validators import CfnTemplateValidator
 
 
@@ -19,6 +20,10 @@ def _resolve(name, instance, expected_results, **kwargs):
     for i, (instance, v, errors) in enumerate(resolutions):
         assert instance == expected_results[i][0]
         assert v.context.path.value_path == expected_results[i][1]
+        if errors:
+            print(errors.validator)
+            print(errors.path)
+            print(errors.schema_path)
         assert errors == expected_results[i][2]
 
 
@@ -216,6 +221,66 @@ def test_invalid_functions(name, instance, response):
             {"Fn::FindInMap": ["foo", "bar", "value", {"DefaultValue": "default"}]},
             [("default", deque([4, "DefaultValue"]), None)],
         ),
+        (
+            "Valid FindInMap with a bad mapping",
+            {"Fn::FindInMap": ["bar", "first", "second"]},
+            [
+                (
+                    None,
+                    deque([]),
+                    ValidationError(
+                        ("'bar' is not one of ['foo']"),
+                        path=deque(["Fn::FindInMap", 0]),
+                    ),
+                )
+            ],
+        ),
+        (
+            "Valid FindInMap with a bad mapping and default",
+            {"Fn::FindInMap": ["bar", "first", "second", {"DefaultValue": "default"}]},
+            [("default", deque([4, "DefaultValue"]), None)],
+        ),
+        (
+            "Valid FindInMap with a bad top key",
+            {"Fn::FindInMap": ["foo", "second", "first"]},
+            [
+                (
+                    None,
+                    deque([]),
+                    ValidationError(
+                        ("'second' is not one of ['first'] for " "mapping 'foo'"),
+                        path=deque(["Fn::FindInMap", 1]),
+                    ),
+                )
+            ],
+        ),
+        (
+            "Valid FindInMap with a bad top key and default",
+            {"Fn::FindInMap": ["foo", "second", "first", {"DefaultValue": "default"}]},
+            [("default", deque([4, "DefaultValue"]), None)],
+        ),
+        (
+            "Valid FindInMap with a bad third key",
+            {"Fn::FindInMap": ["foo", "first", "third"]},
+            [
+                (
+                    None,
+                    deque([]),
+                    ValidationError(
+                        (
+                            "'third' is not one of ['second'] for "
+                            "mapping 'foo' and key 'first'"
+                        ),
+                        path=deque(["Fn::FindInMap", 2]),
+                    ),
+                )
+            ],
+        ),
+        (
+            "Valid FindInMap with a bad second key and default",
+            {"Fn::FindInMap": ["foo", "first", "third", {"DefaultValue": "default"}]},
+            [("default", deque([4, "DefaultValue"]), None)],
+        ),
         (
             "Valid Sub with a resolvable values",
             {"Fn::Sub": ["${a}-${b}", {"a": "foo", "b": "bar"}]},
diff --git a/test/unit/rules/functions/test_find_in_map.py b/test/unit/rules/functions/test_find_in_map.py
index 3524b8b893..77fb167d6c 100644
--- a/test/unit/rules/functions/test_find_in_map.py
+++ b/test/unit/rules/functions/test_find_in_map.py
@@ -26,6 +26,12 @@ def cfn():
     return Template(
         "",
         {
+            "Parameters": {
+                "MyParameter": {
+                    "Type": "String",
+                    "AllowedValues": ["A", "B", "C"],
+                }
+            },
             "Resources": {"MyResource": Resource({"Type": "AWS::SSM::Parameter"})},
             "Mappings": {"A": {"B": {"C": "Value"}}},
         },
@@ -155,12 +161,20 @@ def context(cfn):
             [ValidationError("Foo")],
             [
                 ValidationError(
-                    "'C' is not one of ['B']",
+                    "'C' is not one of ['B'] for mapping 'A'",
                     path=deque(["Fn::FindInMap", 1]),
                     schema_path=deque([]),
                 ),
             ],
         ),
+        (
+            "Valid Fn::FindInMap as the Ref could work",
+            {"Fn::FindInMap": ["A", {"Ref": "MyParameter"}, "C"]},
+            {"type": "string"},
+            {"transforms": Transforms(["AWS::LanguageExtensions"])},
+            [],
+            [],
+        ),
         (
             "Valid Fn::FindInMap with a Ref to AWS::NoValue",
             {
@@ -201,4 +215,5 @@ def test_validate(
         ref_mock.assert_not_called()
     else:
         assert ref_mock.call_count == len(ref_mock_values) or 1
+
     assert errs == expected, f"Test {name!r} got {errs!r}"