Skip to content

Commit

Permalink
fix: exit serialization early on detection of a cycle (#78)
Browse files: browse the repository at this point in the history
* test: add cycle failure cases

* feat: catch serialization cycle

* refactor: rename amber data serializer internal vars

* refactor: use object id to identify visited

* refactor: use custom class for cleaner max depth repr

* refactor: remove unneeded list

* refactor: undo unneeded change
  • Loading branch information
iamogbz authored and Noah committed Dec 29, 2019
1 parent 9b173d5 commit 9f2f396
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 47 deletions.
126 changes: 79 additions & 47 deletions src/syrupy/serializers/amber.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,14 @@


class DataSerializer:
indent: str = " "
name_marker: str = "# name:"
divider: str = "---"
_indent: str = " "
_max_depth: int = 99
_marker_divider: str = "---"
_marker_name: str = "# name:"

class MarkerDepthMax:
def __repr__(self) -> str:
return "..."

@classmethod
def write_file(cls, filepath: str, snapshots: Dict[str, Dict[str, Any]]) -> None:
Expand All @@ -31,10 +36,10 @@ def write_file(cls, filepath: str, snapshots: Dict[str, Dict[str, Any]]) -> None
snapshot = snapshots[key]
snapshot_data = snapshot.get("data")
if snapshot_data is not None:
f.write(f"{cls.name_marker} {key}\n")
f.write(f"{cls._marker_name} {key}\n")
for data_line in snapshot_data.split("\n"):
f.write(f"{cls.indent}{data_line}\n")
f.write(f"{cls.divider}\n")
f.write(f"{cls._indent}{data_line}\n")
f.write(f"{cls._marker_divider}\n")

@classmethod
def read_file(cls, filepath: str) -> Dict[str, Dict[str, Any]]:
Expand All @@ -43,22 +48,22 @@ def read_file(cls, filepath: str) -> Dict[str, Dict[str, Any]]:
of snapshot name to raw data. This does not attempt any deserialization
of the snapshot data.
"""
name_marker_len = len(cls.name_marker)
indent_len = len(cls.indent)
name_marker_len = len(cls._marker_name)
indent_len = len(cls._indent)
snapshots = {}
test_name = None
snapshot_data = ""
try:
with open(filepath, "r") as f:
for line in f:
if line.startswith(cls.name_marker):
if line.startswith(cls._marker_name):
test_name = line[name_marker_len:-1].strip(" \n")
snapshot_data = ""
continue
elif test_name is not None:
if line.startswith(cls.indent):
if line.startswith(cls._indent):
snapshot_data += line[indent_len:]
elif line.startswith(cls.divider) and snapshot_data:
elif line.startswith(cls._marker_divider) and snapshot_data:
snapshots[test_name] = {"data": snapshot_data[:-1]}
except FileNotFoundError:
pass
Expand All @@ -81,79 +86,108 @@ def _sort_key(value: Any) -> Any:
return sorted(iterable, key=_sort_key)

@classmethod
def with_indent(cls, string: str, indent: int) -> str:
return f"{cls.indent * indent}{string}"
def with_indent(cls, string: str, depth: int) -> str:
return f"{cls._indent * depth}{string}"

@classmethod
def object_type(cls, data: "SerializableData") -> str:
return f"<class '{data.__class__.__name__}'>"

@classmethod
def serialize_string(cls, data: "SerializableData", indent: int = 0) -> str:
def serialize_string(
cls, data: "SerializableData", *, depth: int = 0, visited: Set[Any] = set()
) -> str:
if "\n" in data:
return (
cls.with_indent("'\n", indent)
cls.with_indent("'\n", depth)
+ str(data)
+ cls.with_indent("\n'", indent)
+ cls.with_indent("\n'", depth)
)
return cls.with_indent(repr(data), indent)
return cls.with_indent(repr(data), depth)

@classmethod
def serialize_number(cls, data: "SerializableData", indent: int = 0) -> str:
return cls.with_indent(repr(data), indent)
def serialize_number(
cls, data: "SerializableData", *, depth: int = 0, visited: Set[Any] = set()
) -> str:
return cls.with_indent(repr(data), depth)

@classmethod
def serialize_set(cls, data: "SerializableData", indent: int = 0) -> str:
def serialize_set(
cls, data: "SerializableData", *, depth: int = 0, visited: Set[Any] = set()
) -> str:
return (
cls.with_indent(f"{cls.object_type(data)} {{\n", indent)
+ "".join([f"{cls.serialize(d, indent + 1)},\n" for d in cls.sort(data)])
+ cls.with_indent("}", indent)
cls.with_indent(f"{cls.object_type(data)} {{\n", depth)
+ "".join(
f"{cls.serialize(d, depth=depth + 1, visited=visited)},\n"
for d in cls.sort(data)
)
+ cls.with_indent("}", depth)
)

@classmethod
def serialize_dict(cls, data: "SerializableData", indent: int = 0) -> str:
def serialize_dict(
cls, data: "SerializableData", *, depth: int = 0, visited: Set[Any] = set()
) -> str:
kwargs = dict(depth=depth + 1, visited=visited)
return (
cls.with_indent(f"{cls.object_type(data)} {{\n", indent)
cls.with_indent(f"{cls.object_type(data)} {{\n", depth)
+ "".join(
[
f"{serialized_key}: {serialized_value.lstrip(cls._indent)},\n"
for serialized_key, serialized_value in (
(
cls.serialize(key, indent + 1)
+ ": "
+ cls.serialize(data[key], indent + 1).lstrip(cls.indent)
+ ",\n"
cls.serialize(**dict(data=key, **kwargs)),
cls.serialize(**dict(data=data[key], **kwargs)),
)
for key in cls.sort(data.keys())
]
)
)
+ cls.with_indent("}", indent)
+ cls.with_indent("}", depth)
)

@classmethod
def serialize_iterable(cls, data: "SerializableData", indent: int = 0) -> str:
def serialize_iterable(
cls, data: "SerializableData", *, depth: int = 0, visited: Set[Any] = set()
) -> str:
open_paren, close_paren = next(
paren[1]
for paren in {list: "[]", tuple: "()", GeneratorType: "()"}.items()
if isinstance(data, paren[0])
)
return (
cls.with_indent(f"{cls.object_type(data)} {open_paren}\n", indent)
+ "".join([f"{cls.serialize(d, indent + 1)},\n" for d in data])
+ cls.with_indent(close_paren, indent)
cls.with_indent(f"{cls.object_type(data)} {open_paren}\n", depth)
+ "".join(
f"{cls.serialize(d, depth=depth + 1, visited=visited)},\n" for d in data
)
+ cls.with_indent(close_paren, depth)
)

@classmethod
def serialize(cls, data: "SerializableData", indent: int = 0) -> str:
def serialize_unknown(
cls, data: Any, *, depth: int = 0, visited: Set[Any] = set()
) -> str:
return cls.with_indent(repr(data), depth)

@classmethod
def serialize(
cls, data: "SerializableData", *, depth: int = 0, visited: Set[Any] = set()
) -> str:
data_id = id(data)
if depth > cls._max_depth or data_id in visited:
data = cls.MarkerDepthMax()

serialize_kwargs = dict(data=data, depth=depth, visited={*visited, data_id})
serialize_method = cls.serialize_unknown
if isinstance(data, str):
return cls.serialize_string(data, indent)
serialize_method = cls.serialize_string
elif isinstance(data, (int, float)):
return cls.serialize_number(data, indent)
serialize_method = cls.serialize_number
elif isinstance(data, (set, frozenset)):
return cls.serialize_set(data, indent)
serialize_method = cls.serialize_set
elif isinstance(data, dict):
return cls.serialize_dict(data, indent)
serialize_method = cls.serialize_dict
elif isinstance(data, (list, tuple, GeneratorType)):
return cls.serialize_iterable(data, indent)
return cls.with_indent(repr(data), indent)
serialize_method = cls.serialize_iterable
return serialize_method(**serialize_kwargs)


class AmberSnapshotSerializer(AbstractSnapshotSerializer):
Expand All @@ -162,12 +196,10 @@ class AmberSnapshotSerializer(AbstractSnapshotSerializer):
```
# name: test_name_1
data
data
---
# name: test_name_2
data
data
```
"""

Expand Down
16 changes: 16 additions & 0 deletions tests/__snapshots__/test_amber_serializer.ambr
Original file line number Diff line number Diff line change
@@ -1,6 +1,22 @@
# name: TestClass.test_name
'this is in a test class'
---
# name: test_cycle[cyclic0]
<class 'list'> [
1,
2,
3,
...,
]
---
# name: test_cycle[cyclic1]
<class 'dict'> {
'a': 1,
'b': 2,
'c': 3,
'd': ...,
}
---
# name: test_dict[actual0]
<class 'dict'> {
'a': <class 'dict'> {
Expand Down
12 changes: 12 additions & 0 deletions tests/test_amber_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,18 @@ def test_list(snapshot):
assert snapshot == [1, 2, "string", {"key": "value"}]


list_cycle = [1, 2, 3]
list_cycle.append(list_cycle)

dict_cycle = {"a": 1, "b": 2, "c": 3}
dict_cycle.update(d=dict_cycle)


@pytest.mark.parametrize("cyclic", [list_cycle, dict_cycle])
def test_cycle(cyclic, snapshot):
assert cyclic == snapshot


class TestClass:
def test_name(self, snapshot):
assert snapshot == "this is in a test class"

0 comments on commit 9f2f396

Please sign in to comment.