Skip to content

Commit

Permalink
update schema filtering to use new condition logic
Browse files Browse the repository at this point in the history
  • Loading branch information
kddejong committed Jul 31, 2024
1 parent 8f1d89a commit 152e8d8
Show file tree
Hide file tree
Showing 8 changed files with 275 additions and 18 deletions.
32 changes: 31 additions & 1 deletion src/cfnlint/context/conditions/_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,20 @@

from __future__ import annotations

import itertools
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Iterator

from sympy import Not, Or
from sympy.logic.boolalg import BooleanFunction
from sympy.logic.inference import satisfiable

from cfnlint.conditions._utils import get_hash
from cfnlint.context.conditions._condition import Condition
from cfnlint.context.conditions._utils import (
build_instance_from_scenario,
get_conditions_from_property,
)
from cfnlint.context.conditions.exceptions import Unsatisfiable

if TYPE_CHECKING:
Expand All @@ -25,6 +30,7 @@ class Conditions:
# Template level condition management
conditions: dict[str, Condition] = field(init=True, default_factory=dict)
cnf: BooleanFunction | None = field(init=True, default=None)
_max_scenarios: int = field(init=False, default=128)

@classmethod
def create_from_instance(
Expand Down Expand Up @@ -106,6 +112,30 @@ def evolve(self, status: dict[str, bool]) -> "Conditions":
cnf=cnf,
)

def _build_conditions(self, conditions: set[str]) -> Iterator["Conditions"]:
scenarios_attempted = 0
for product in itertools.product([True, False], repeat=len(conditions)):
params = dict(zip(conditions, product))
try:
yield self.evolve(params)
except Unsatisfiable:
pass

scenarios_attempted += 1
# On occassions people will use a lot of non-related conditions
# this is fail safe to limit the maximum number of responses
if scenarios_attempted >= self._max_scenarios:
return

Check warning on line 128 in src/cfnlint/context/conditions/_conditions.py

View check run for this annotation

Codecov / codecov/patch

src/cfnlint/context/conditions/_conditions.py#L128

Added line #L128 was not covered by tests

def evolve_from_instance(self, instance: Any) -> Iterator[tuple[Any, "Conditions"]]:

conditions = get_conditions_from_property(instance)

for scenario in self._build_conditions(conditions):
yield build_instance_from_scenario(
instance, scenario.status, is_root=True
), scenario

@property
def status(self) -> dict[str, bool]:
obj = {}
Expand Down
112 changes: 112 additions & 0 deletions src/cfnlint/context/conditions/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
SPDX-License-Identifier: MIT-0
"""

from __future__ import annotations

from typing import Any

from cfnlint.decode.node import Mark, dict_node, list_node
from cfnlint.helpers import is_function


def get_conditions_from_property(instance: Any, is_root: bool = True) -> set[str]:
"""
Gets the name of the conditions used directly inside the object.
We do not look at nested objects for conditions.
Args:
instance (Any): The object or listto process.
is_root (bool): If we are at the root of the object. Default: True.
Returns:
set[str]: The set of conditions used in the object or list.
"""
results: set[str] = set()
if isinstance(instance, list):
for v in instance:
results = results.union(get_conditions_from_property(v, is_root=False))
return results

fn_k, fn_v = is_function(instance)
if fn_k == "Fn::If":
if isinstance(fn_v, list) and len(fn_v) == 3:
if isinstance(fn_v[0], str):
results.add(fn_v[0])
results = results.union(
get_conditions_from_property(fn_v[1], is_root=is_root)
)
results = results.union(
get_conditions_from_property(fn_v[2], is_root=is_root)
)
return results
if is_root:
if isinstance(instance, dict):
for k, v in instance.items():
results = results.union(get_conditions_from_property(v, is_root=False))
return results


def build_instance_from_scenario(
instance: Any, scenario: dict[str, bool], is_root: bool = True
) -> Any:
"""
Get object values from a provided scenario.
This function recursively processes the provided object, resolving any
conditional logic (such as Fn::If) based on the given scenario.
Args:
instance (Any): The object or listto process.
scenario (dict): The scenario to use when resolving conditional logic.
is_root (bool): If we are at the root of the object. Default: True.
Returns:
dict or list or any: The processed object, with conditional logic resolved.
The return type can be a dictionary, list, or any other data type,
depending on the structure of the input object.
"""

if isinstance(instance, list):
new_list: list[Any] = list_node(
[],
getattr(instance, "start_mark", Mark(0, 0)),
getattr(instance, "end_mark", Mark(0, 0)),
)
for v in instance:
new_value = build_instance_from_scenario(v, scenario, is_root=False)
if new_value is not None:
new_list.append(new_value)
return new_list

if isinstance(instance, dict):
fn_k, fn_v = is_function(instance)
if fn_k == "Fn::If":
if isinstance(fn_v, list) and len(fn_v) == 3:
if isinstance(fn_v[0], str):
if_path = scenario.get(fn_v[0], None)
if if_path is not None:
new_value = build_instance_from_scenario(
fn_v[1] if if_path else fn_v[2], scenario, is_root
)
if new_value is not None:
return new_value
return None
return instance
if fn_k == "Ref" and fn_v == "AWS::NoValue":
return {} if is_root else None
if is_root:
new_obj: dict[str, Any] = dict_node(
{},
getattr(instance, "start_mark", Mark(0, 0)),
getattr(instance, "end_mark", Mark(0, 0)),
)
for k, v in instance.items():
new_value = build_instance_from_scenario(v, scenario, is_root=False)
if new_value is not None:
new_obj[k] = new_value
return new_obj

return instance
25 changes: 16 additions & 9 deletions src/cfnlint/jsonschema/_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations

from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, Any, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Iterator, Sequence, Tuple

from cfnlint.helpers import FUNCTIONS, REGEX_DYN_REF, ToPy, ensure_list

Expand Down Expand Up @@ -107,7 +107,9 @@ def _filter_schemas(self, schema, validator: Validator) -> Tuple[Any, Any]:

return standard_schema, group_schema

def filter(self, validator: Any, instance: Any, schema: Any):
def filter(
self, validator: Any, instance: Any, schema: Any
) -> Iterator[tuple[Any, dict[str, Any], "Validator"]]:
# Lets validate dynamic references when appropriate
if validator.is_type(instance, "string") and self.validate_dynamic_references:
if REGEX_DYN_REF.findall(instance):
Expand All @@ -117,14 +119,14 @@ def filter(self, validator: Any, instance: Any, schema: Any):
p in set(FUNCTIONS) - set(["Fn::If"])
for p in validator.context.path.path
):
yield (instance, {"dynamicReference": schema})
yield (instance, {"dynamicReference": schema}, validator)
return
return

# if there are no functions then we don't need to worry
# about ref AWS::NoValue or If conditions
if not validator.context.functions:
yield instance, schema
yield instance, schema, validator
return

# dependencies, required, minProperties, maxProperties
Expand All @@ -133,9 +135,14 @@ def filter(self, validator: Any, instance: Any, schema: Any):
standard_schema, group_schema = self._filter_schemas(schema, validator)

if group_schema:
scenarios = validator.cfn.get_object_without_conditions(instance)
for scenario in scenarios:
yield (scenario.get("Object"), group_schema)
for (
scenario_instance,
scenario,
) in validator.context.conditions.evolve_from_instance(instance):
scenario_validator = validator.evolve(
context=validator.context.evolve(conditions=scenario)
)
yield (scenario_instance, group_schema, scenario_validator)

if validator.is_type(instance, "object"):
if len(instance) == 1:
Expand All @@ -145,10 +152,10 @@ def filter(self, validator: Any, instance: Any, schema: Any):
k_schema = {
k_py.py: standard_schema,
}
yield (instance, k_schema)
yield (instance, k_schema, validator)
return

yield (instance, standard_schema)
yield (instance, standard_schema, validator)

def evolve(self, **kwargs) -> "FunctionFilter":
"""
Expand Down
6 changes: 4 additions & 2 deletions src/cfnlint/jsonschema/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def iter_errors(self, instance: Any) -> ValidationResult:
self.resolver.push_scope(scope)
try:
# we need filter and apply schemas against the new instances
for _instance, _schema in self.function_filter.filter(
for _instance, _schema, _validator in self.function_filter.filter(
self, instance, schema
):
for k, v in _schema.items():
Expand All @@ -238,7 +238,9 @@ def iter_errors(self, instance: Any) -> ValidationResult:
continue

try:
for err in validator(self, v, _instance, _schema) or ():
for err in (
validator(_validator, v, _instance, _schema) or ()
):
msg = custom_msg(k, _schema) or err.message
if msg is not None:
err.message = msg
Expand Down
2 changes: 1 addition & 1 deletion src/cfnlint/template/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,7 @@ def get_value(value, scenario): # pylint: disable=R0911
result[key] = new_value
return result
if isinstance(obj, list):
result = list_node({}, obj.start_mark, obj.end_mark)
result = list_node([], obj.start_mark, obj.end_mark)

Check warning on line 965 in src/cfnlint/template/template.py

View check run for this annotation

Codecov / codecov/patch

src/cfnlint/template/template.py#L965

Added line #L965 was not covered by tests
for item in obj:
element = get_value(item, scenario)
if element is not None:
Expand Down
106 changes: 106 additions & 0 deletions test/unit/module/context/conditions/test_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,112 @@ def test_condition_status(current_status, new_status, expected):
assert context.conditions.status == expected


@pytest.mark.parametrize(
"current_status,instance,expected",
[
({}, {"Foo": "Bar"}, [({"Foo": "Bar"}, {})]),
(
{},
{"Fn::If": ["IsUsEast1", {"Foo": "Foo"}, {"Bar": "Bar"}]},
[
({"Foo": "Foo"}, {"IsUsEast1": True}),
({"Bar": "Bar"}, {"IsUsEast1": False}),
],
),
(
{
"IsUsEast1": True,
},
{"Fn::If": ["IsUsEast1", {"Foo": "Foo"}, {"Bar": "Bar"}]},
[
({"Foo": "Foo"}, {"IsUsEast1": True}),
],
),
(
{
"IsUsEast1": False,
},
{"Fn::If": ["IsUsEast1", {"Foo": "Foo"}, {"Bar": "Bar"}]},
[
({"Bar": "Bar"}, {"IsUsEast1": False}),
],
),
(
{},
{"Ref": "AWS::NoValue"},
[
({}, {}),
],
),
(
{},
[{"Foo": {"Fn::If": ["IsUsEast1", "Foo", "Bar"]}}],
[
([{"Foo": {"Fn::If": ["IsUsEast1", "Foo", "Bar"]}}], {}),
],
),
(
{},
[{"Fn::If": ["IsUsEast1", {"Foo": "Bar"}, {"Ref": "AWS::NoValue"}]}],
[
([{"Foo": "Bar"}], {"IsUsEast1": True}),
([], {"IsUsEast1": False}),
],
),
(
{"IsUsEast1": True, "IsProd": True},
{
"A": {"Fn::If": ["IsUsEast1AndProd", 1, 2]},
"B": {"Fn::If": ["IsAi", 10, 11]},
},
[
(
{"A": 1, "B": 10},
{
"IsUsEast1": True,
"IsProd": True,
"IsUsEast1AndProd": True,
"IsAi": True,
},
),
(
{"A": 1, "B": 11},
{
"IsUsEast1": True,
"IsProd": True,
"IsUsEast1AndProd": True,
"IsAi": False,
},
),
],
),
(
{},
{
"A": {"Fn::If": ["IsUsEast1AndProd", 1]},
},
[
(
{"A": {"Fn::If": ["IsUsEast1AndProd", 1]}},
{},
),
],
),
],
)
def test_evolve_from_instance(current_status, instance, expected):
cfn = Template(None, template(), regions=["us-east-1"])
context = create_context_for_template(cfn)

context = context.evolve(conditions=context.conditions.evolve(current_status))

results = list(context.conditions.evolve_from_instance(instance))
assert len(results) == len(expected)
for result, expected_result in zip(results, expected):
assert result[0] == expected_result[0]
assert result[1].status == expected_result[1]


def test_condition_failures():
with pytest.raises(ValueError):
Conditions.create_from_instance([], {})
6 changes: 5 additions & 1 deletion test/unit/module/jsonschema/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,8 @@ def test_filter(name, instance, schema, path, functions, expected, filter):
)
results = list(filter.filter(validator, instance, schema))

assert results == expected, f"For test {name} got {results!r}"
assert len(results) == len(expected), f"For test {name} got {len(results)} results"

for result, (exp_instance, exp_schema) in zip(results, expected):
assert result[0] == exp_instance, f"For test {name} got {result.instance!r}"
assert result[1] == exp_schema, f"For test {name} got {result.schema!r}"
Loading

0 comments on commit 152e8d8

Please sign in to comment.