diff --git a/src/syrupy/serializers/amber.py b/src/syrupy/serializers/amber.py index f99f7a10..d0cfe27d 100644 --- a/src/syrupy/serializers/amber.py +++ b/src/syrupy/serializers/amber.py @@ -17,9 +17,14 @@ class DataSerializer: - indent: str = " " - name_marker: str = "# name:" - divider: str = "---" + _indent: str = " " + _max_depth: int = 99 + _marker_divider: str = "---" + _marker_name: str = "# name:" + + class MarkerDepthMax: + def __repr__(self) -> str: + return "..." @classmethod def write_file(cls, filepath: str, snapshots: Dict[str, Dict[str, Any]]) -> None: @@ -31,10 +36,10 @@ def write_file(cls, filepath: str, snapshots: Dict[str, Dict[str, Any]]) -> None snapshot = snapshots[key] snapshot_data = snapshot.get("data") if snapshot_data is not None: - f.write(f"{cls.name_marker} {key}\n") + f.write(f"{cls._marker_name} {key}\n") for data_line in snapshot_data.split("\n"): - f.write(f"{cls.indent}{data_line}\n") - f.write(f"{cls.divider}\n") + f.write(f"{cls._indent}{data_line}\n") + f.write(f"{cls._marker_divider}\n") @classmethod def read_file(cls, filepath: str) -> Dict[str, Dict[str, Any]]: @@ -43,22 +48,22 @@ def read_file(cls, filepath: str) -> Dict[str, Dict[str, Any]]: of snapshot name to raw data. This does not attempt any deserialization of the snapshot data. """ - name_marker_len = len(cls.name_marker) - indent_len = len(cls.indent) + name_marker_len = len(cls._marker_name) + indent_len = len(cls._indent) snapshots = {} test_name = None snapshot_data = "" try: with open(filepath, "r") as f: for line in f: - if line.startswith(cls.name_marker): + if line.startswith(cls._marker_name): test_name = line[name_marker_len:-1].strip(" \n") snapshot_data = "" continue elif test_name is not None: - if line.startswith(cls.indent): + if line.startswith(cls._indent): snapshot_data += line[indent_len:] - elif line.startswith(cls.divider) and snapshot_data: + elif line.startswith(cls._marker_divider) and snapshot_data: snapshots[test_name] = {"data": snapshot_data[:-1]} except FileNotFoundError: pass @@ -81,79 +86,108 @@ def _sort_key(value: Any) -> Any: return sorted(iterable, key=_sort_key) @classmethod - def with_indent(cls, string: str, indent: int) -> str: - return f"{cls.indent * indent}{string}" + def with_indent(cls, string: str, depth: int) -> str: + return f"{cls._indent * depth}{string}" @classmethod def object_type(cls, data: "SerializableData") -> str: return f"" @classmethod - def serialize_string(cls, data: "SerializableData", indent: int = 0) -> str: + def serialize_string( + cls, data: "SerializableData", *, depth: int = 0, visited: Set[Any] = set() + ) -> str: if "\n" in data: return ( - cls.with_indent("'\n", indent) + cls.with_indent("'\n", depth) + str(data) - + cls.with_indent("\n'", indent) + + cls.with_indent("\n'", depth) ) - return cls.with_indent(repr(data), indent) + return cls.with_indent(repr(data), depth) @classmethod - def serialize_number(cls, data: "SerializableData", indent: int = 0) -> str: - return cls.with_indent(repr(data), indent) + def serialize_number( + cls, data: "SerializableData", *, depth: int = 0, visited: Set[Any] = set() + ) -> str: + return cls.with_indent(repr(data), depth) @classmethod - def serialize_set(cls, data: "SerializableData", indent: int = 0) -> str: + def serialize_set( + cls, data: "SerializableData", *, depth: int = 0, visited: Set[Any] = set() + ) -> str: return ( - cls.with_indent(f"{cls.object_type(data)} {{\n", indent) - + "".join([f"{cls.serialize(d, indent + 1)},\n" for d in cls.sort(data)]) - + cls.with_indent("}", indent) + cls.with_indent(f"{cls.object_type(data)} {{\n", depth) + + "".join( + f"{cls.serialize(d, depth=depth + 1, visited=visited)},\n" + for d in cls.sort(data) + ) + + cls.with_indent("}", depth) ) @classmethod - def serialize_dict(cls, data: "SerializableData", indent: int = 0) -> str: + def serialize_dict( + cls, data: "SerializableData", *, depth: int = 0, visited: Set[Any] = set() + ) -> str: + kwargs = dict(depth=depth + 1, visited=visited) return ( - cls.with_indent(f"{cls.object_type(data)} {{\n", indent) + cls.with_indent(f"{cls.object_type(data)} {{\n", depth) + "".join( - [ + f"{serialized_key}: {serialized_value.lstrip(cls._indent)},\n" + for serialized_key, serialized_value in ( ( - cls.serialize(key, indent + 1) - + ": " - + cls.serialize(data[key], indent + 1).lstrip(cls.indent) - + ",\n" + cls.serialize(**dict(data=key, **kwargs)), + cls.serialize(**dict(data=data[key], **kwargs)), ) for key in cls.sort(data.keys()) - ] + ) ) - + cls.with_indent("}", indent) + + cls.with_indent("}", depth) ) @classmethod - def serialize_iterable(cls, data: "SerializableData", indent: int = 0) -> str: + def serialize_iterable( + cls, data: "SerializableData", *, depth: int = 0, visited: Set[Any] = set() + ) -> str: open_paren, close_paren = next( paren[1] for paren in {list: "[]", tuple: "()", GeneratorType: "()"}.items() if isinstance(data, paren[0]) ) return ( - cls.with_indent(f"{cls.object_type(data)} {open_paren}\n", indent) - + "".join([f"{cls.serialize(d, indent + 1)},\n" for d in data]) - + cls.with_indent(close_paren, indent) + cls.with_indent(f"{cls.object_type(data)} {open_paren}\n", depth) + + "".join( + f"{cls.serialize(d, depth=depth + 1, visited=visited)},\n" for d in data + ) + + cls.with_indent(close_paren, depth) ) @classmethod - def serialize(cls, data: "SerializableData", indent: int = 0) -> str: + def serialize_unknown( + cls, data: Any, *, depth: int = 0, visited: Set[Any] = set() + ) -> str: + return cls.with_indent(repr(data), depth) + + @classmethod + def serialize( + cls, data: "SerializableData", *, depth: int = 0, visited: Set[Any] = set() + ) -> str: + data_id = id(data) + if depth > cls._max_depth or data_id in visited: + data = cls.MarkerDepthMax() + + serialize_kwargs = dict(data=data, depth=depth, visited={*visited, data_id}) + serialize_method = cls.serialize_unknown if isinstance(data, str): - return cls.serialize_string(data, indent) + serialize_method = cls.serialize_string elif isinstance(data, (int, float)): - return cls.serialize_number(data, indent) + serialize_method = cls.serialize_number elif isinstance(data, (set, frozenset)): - return cls.serialize_set(data, indent) + serialize_method = cls.serialize_set elif isinstance(data, dict): - return cls.serialize_dict(data, indent) + serialize_method = cls.serialize_dict elif isinstance(data, (list, tuple, GeneratorType)): - return cls.serialize_iterable(data, indent) - return cls.with_indent(repr(data), indent) + serialize_method = cls.serialize_iterable + return serialize_method(**serialize_kwargs) class AmberSnapshotSerializer(AbstractSnapshotSerializer): @@ -162,12 +196,10 @@ class AmberSnapshotSerializer(AbstractSnapshotSerializer): ``` # name: test_name_1 - - data + data --- # name: test_name_2 - - data + data ``` """ diff --git a/tests/__snapshots__/test_amber_serializer.ambr b/tests/__snapshots__/test_amber_serializer.ambr index 8718fe23..b194fd48 100644 --- a/tests/__snapshots__/test_amber_serializer.ambr +++ b/tests/__snapshots__/test_amber_serializer.ambr @@ -1,6 +1,22 @@ # name: TestClass.test_name 'this is in a test class' --- +# name: test_cycle[cyclic0] + [ + 1, + 2, + 3, + ..., + ] +--- +# name: test_cycle[cyclic1] + { + 'a': 1, + 'b': 2, + 'c': 3, + 'd': ..., + } +--- # name: test_dict[actual0] { 'a': { diff --git a/tests/test_amber_serializer.py b/tests/test_amber_serializer.py index 9d09740f..b51f60e6 100644 --- a/tests/test_amber_serializer.py +++ b/tests/test_amber_serializer.py @@ -72,6 +72,18 @@ def test_list(snapshot): assert snapshot == [1, 2, "string", {"key": "value"}] +list_cycle = [1, 2, 3] +list_cycle.append(list_cycle) + +dict_cycle = {"a": 1, "b": 2, "c": 3} +dict_cycle.update(d=dict_cycle) + + +@pytest.mark.parametrize("cyclic", [list_cycle, dict_cycle]) +def test_cycle(cyclic, snapshot): + assert cyclic == snapshot + + class TestClass: def test_name(self, snapshot): assert snapshot == "this is in a test class"