diff --git a/snowfakery/parse_recipe_yaml.py b/snowfakery/parse_recipe_yaml.py index 520d91c8..556af422 100644 --- a/snowfakery/parse_recipe_yaml.py +++ b/snowfakery/parse_recipe_yaml.py @@ -233,7 +233,7 @@ def parse_count_expression(yaml_sobj: Dict, sobj_def: Dict, context: ParseContex def include_macro( - name: str, context: ParseContext + name: str, context: ParseContext, parent_macros=() ) -> Tuple[List[FieldFactory], List[TemplateLike]]: macro = context.macros.get(name) if not macro: @@ -241,22 +241,36 @@ def include_macro( f"Cannot find macro named {name}", **context.line_num() ) parsed_macro = parse_element( - macro, "macro", {}, {"fields": Dict, "friends": List}, context + macro, "macro", {}, {"fields": Dict, "friends": List, "include": str}, context ) fields = parsed_macro.fields or {} friends = parsed_macro.friends or [] - return parse_fields(fields, context), parse_friends(friends, context) + fields, friends = parse_fields(fields, context), parse_friends(friends, context) + if name in parent_macros: + idx = parent_macros.index(name) + raise exc.DataGenError( + f"Macro `{name}` calls `{'` which calls `'.join(parent_macros[idx+1:])}` which calls `{name}`", + **context.line_num(macro), + ) + parse_inclusions(macro, fields, friends, context, parent_macros + (name,)) + return fields, friends def parse_inclusions( - yaml_sobj: Dict, fields: List, friends: List, context: ParseContext + yaml_sobj: Dict, + fields: List, + friends: List, + context: ParseContext, + parent_macros=(), ) -> None: inclusions: Iterable[str] = [ x.strip() for x in yaml_sobj.get("include", "").split(",") ] inclusions = filter(None, inclusions) for inclusion in inclusions: - include_fields, include_friends = include_macro(inclusion, context) + include_fields, include_friends = include_macro( + inclusion, context, parent_macros + ) fields.extend(include_fields) friends.extend(include_friends) diff --git a/tests/macros-include-macros.yml b/tests/macros-include-macros.yml new file mode 100644 index 00000000..aed45edd --- /dev/null +++ b/tests/macros-include-macros.yml @@ -0,0 +1,11 @@ +- macro: foo + fields: + foobar: FOOBAR + +- macro: bar + include: foo + fields: + barbar: BARBAR + +- object: Bar + include: bar diff --git a/tests/macros-recurse-unbounded.yml b/tests/macros-recurse-unbounded.yml new file mode 100644 index 00000000..46db350e --- /dev/null +++ b/tests/macros-recurse-unbounded.yml @@ -0,0 +1,17 @@ +- macro: baz + include: bar + fields: + foobar: FOOBAR + +- macro: foo + include: baz + fields: + foobar: FOOBAR + +- macro: bar + include: foo + fields: + barbar: BARBAR + +- object: Bar + include: bar diff --git a/tests/test_macros.py b/tests/test_macros.py index e5a7a1da..8d85da72 100644 --- a/tests/test_macros.py +++ b/tests/test_macros.py @@ -1,8 +1,10 @@ import unittest from unittest import mock from io import StringIO +import pytest from snowfakery.data_generator import generate +import snowfakery.data_gen_exceptions as exc write_row_path = "snowfakery.output_streams.DebugOutputStream.write_row" @@ -77,3 +79,44 @@ def test_friend_includes_and_references(self, write_row): assert write_row.mock_calls[0] == mock.call("foo", {"id": 1}) assert write_row.mock_calls[1][1][0] == "bar" assert write_row.mock_calls[1][1][1]["myfoo"].id == 1 + + @mock.patch("snowfakery.output_streams.DebugOutputStream.write_row") + def test_macros_include_macros(self, write_row): + yaml = """ + - macro: foo + fields: + foobar: FOOBAR + + - macro: bar + include: foo + fields: + barbar: BARBAR + + - object: Bar + include: bar + """ + generate(StringIO(yaml)) + assert write_row.mock_calls[0] == mock.call( + "Bar", {"id": 1, "barbar": "BARBAR", "foobar": "FOOBAR"} + ) + + @mock.patch("snowfakery.output_streams.DebugOutputStream.write_row") + def test_macros_include_themselves(self, write_row): + yaml = """ + - macro: foo + include: bar + fields: + foobar: FOOBAR + + - macro: bar + include: foo + fields: + barbar: BARBAR + + - object: Bar + include: bar + """ + with pytest.raises(exc.DataGenError) as e: + generate(StringIO(yaml)) + assert "foo" in str(e.value) + assert "bar" in str(e.value)