Skip to content

Commit

Permalink
refactor: write performance improvements, api clarity (#645)
Browse files Browse the repository at this point in the history
* fix: group snapshot writes by extension class

* refactor: rename PyTestLocation.filename to .basename

BREAKING CHANGE: PyTestLocation.filename has been renamed to .basename

* refactor: add test_location kwarg to get_snapshot_name

* refactor: get_snapshot_name is now static as a classmethod

* refactor: remove pre and post read/write hooks

BREAKING CHANGE: Pre and post read/write hooks have been removed without replacement to make internal refactor simpler. Please open a GitHub issue if you have a use case for these hooks.

* refactor: rename Fossil to Collection

BREAKING CHANGE: The term 'fossil' has been replaced by the clearer term 'collection'.

* refactor: pass test_location to read_snapshot

* refactor: remove singular write_snapshot method

* refactor: dirname property to method

* refactor: pass test_location to discover_snapshots

* refactor: remove usage of self.test_location

* refactor: make write_snapshot a classmethod

* refactor: do not instantiate extension with test_location

BREAKING CHANGE: Numerous instance methods have been refactored as classmethods.
  • Loading branch information
noahnu authored and Noah Negin-Ulster committed Dec 30, 2022
1 parent 2fdfb10 commit 1038c30
Show file tree
Hide file tree
Showing 22 changed files with 399 additions and 351 deletions.
24 changes: 24 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,30 @@ Fill in the relevant sections, clearly linking the issue the change is attemping

`debugpy` is installed in local development. A VSCode launch config is provided. Run `inv test -v -d` to enable the debugger (`-d` for debug). It'll then wait for you to attach your VSCode debugging client.

#### Debugging Performance Issues

You can run `inv benchmark` to run the full benchmark suite. Alternatively, write a test file, e.g.:

```py
# test_performance.py
import pytest
import os

SIZE = int(os.environ.get("SIZE", 1000))

@pytest.mark.parametrize("x", range(SIZE))
def test_performance(x, snapshot):
assert x == snapshot
```

and then run:

```sh
SIZE=1000 python -m cProfile -s cumtime -m pytest test_performance.py --snapshot-update -s > profile.log
```

See the cProfile docs for metric sorting options.

## Styleguides

### Commit Messages
Expand Down
15 changes: 11 additions & 4 deletions src/syrupy/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __post_init__(self) -> None:
def __init_extension(
self, extension_class: Type["AbstractSyrupyExtension"]
) -> "AbstractSyrupyExtension":
return extension_class(test_location=self.test_location)
return extension_class()

@property
def extension(self) -> "AbstractSyrupyExtension":
Expand Down Expand Up @@ -238,8 +238,12 @@ def __eq__(self, other: "SerializableData") -> bool:
return self._assert(other)

def _assert(self, data: "SerializableData") -> bool:
snapshot_location = self.extension.get_location(index=self.index)
snapshot_name = self.extension.get_snapshot_name(index=self.index)
snapshot_location = self.extension.get_location(
test_location=self.test_location, index=self.index
)
snapshot_name = self.extension.get_snapshot_name(
test_location=self.test_location, index=self.index
)
snapshot_data: Optional["SerializedData"] = None
serialized_data: Optional["SerializedData"] = None
matches = False
Expand All @@ -264,6 +268,7 @@ def _assert(self, data: "SerializableData") -> bool:
if not matches and self.update_snapshots:
self.session.queue_snapshot_write(
extension=self.extension,
test_location=self.test_location,
data=serialized_data,
index=self.index,
)
Expand Down Expand Up @@ -299,7 +304,9 @@ def _post_assert(self) -> None:
def _recall_data(self, index: "SnapshotIndex") -> Optional["SerializableData"]:
try:
return self.extension.read_snapshot(
index=index, session_id=str(id(self.session))
test_location=self.test_location,
index=index,
session_id=str(id(self.session)),
)
except SnapshotDoesNotExist:
return None
4 changes: 2 additions & 2 deletions src/syrupy/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
SNAPSHOT_DIRNAME = "__snapshots__"
SNAPSHOT_EMPTY_FOSSIL_KEY = "empty snapshot fossil"
SNAPSHOT_UNKNOWN_FOSSIL_KEY = "unknown snapshot fossil"
SNAPSHOT_EMPTY_FOSSIL_KEY = "empty snapshot collection"
SNAPSHOT_UNKNOWN_FOSSIL_KEY = "unknown snapshot collection"

EXIT_STATUS_FAIL_UNUSED = 1

Expand Down
52 changes: 26 additions & 26 deletions src/syrupy/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class SnapshotUnknown(Snapshot):


@dataclass
class SnapshotFossil:
class SnapshotCollection:
"""A collection of snapshots at a save location"""

location: str
Expand All @@ -54,8 +54,8 @@ def add(self, snapshot: "Snapshot") -> None:
if snapshot.name != SNAPSHOT_EMPTY_FOSSIL_KEY:
self.remove(SNAPSHOT_EMPTY_FOSSIL_KEY)

def merge(self, snapshot_fossil: "SnapshotFossil") -> None:
for snapshot in snapshot_fossil:
def merge(self, snapshot_collection: "SnapshotCollection") -> None:
for snapshot in snapshot_collection:
self.add(snapshot)

def remove(self, snapshot_name: str) -> None:
Expand All @@ -69,8 +69,8 @@ def __iter__(self) -> Iterator["Snapshot"]:


@dataclass
class SnapshotEmptyFossil(SnapshotFossil):
"""This is a saved fossil that is known to be empty and thus can be removed"""
class SnapshotEmptyCollection(SnapshotCollection):
"""This is a saved collection that is known to be empty and thus can be removed"""

_snapshots: Dict[str, "Snapshot"] = field(
default_factory=lambda: {SnapshotEmpty().name: SnapshotEmpty()}
Expand All @@ -82,42 +82,42 @@ def has_snapshots(self) -> bool:


@dataclass
class SnapshotUnknownFossil(SnapshotFossil):
"""This is a saved fossil that is unclaimed by any extension currently in use"""
class SnapshotUnknownCollection(SnapshotCollection):
"""This is a saved collection that is unclaimed by any extension currently in use"""

_snapshots: Dict[str, "Snapshot"] = field(
default_factory=lambda: {SnapshotUnknown().name: SnapshotUnknown()}
)


@dataclass
class SnapshotFossils:
_snapshot_fossils: Dict[str, "SnapshotFossil"] = field(default_factory=dict)
class SnapshotCollections:
_snapshot_collections: Dict[str, "SnapshotCollection"] = field(default_factory=dict)

def get(self, location: str) -> Optional["SnapshotFossil"]:
return self._snapshot_fossils.get(location)
def get(self, location: str) -> Optional["SnapshotCollection"]:
return self._snapshot_collections.get(location)

def add(self, snapshot_fossil: "SnapshotFossil") -> None:
self._snapshot_fossils[snapshot_fossil.location] = snapshot_fossil
def add(self, snapshot_collection: "SnapshotCollection") -> None:
self._snapshot_collections[snapshot_collection.location] = snapshot_collection

def update(self, snapshot_fossil: "SnapshotFossil") -> None:
snapshot_fossil_to_update = self.get(snapshot_fossil.location)
if snapshot_fossil_to_update is None:
snapshot_fossil_to_update = SnapshotFossil(
location=snapshot_fossil.location
def update(self, snapshot_collection: "SnapshotCollection") -> None:
snapshot_collection_to_update = self.get(snapshot_collection.location)
if snapshot_collection_to_update is None:
snapshot_collection_to_update = SnapshotCollection(
location=snapshot_collection.location
)
self.add(snapshot_fossil_to_update)
snapshot_fossil_to_update.merge(snapshot_fossil)
self.add(snapshot_collection_to_update)
snapshot_collection_to_update.merge(snapshot_collection)

def merge(self, snapshot_fossils: "SnapshotFossils") -> None:
for snapshot_fossil in snapshot_fossils:
self.update(snapshot_fossil)
def merge(self, snapshot_collections: "SnapshotCollections") -> None:
for snapshot_collection in snapshot_collections:
self.update(snapshot_collection)

def __iter__(self) -> Iterator["SnapshotFossil"]:
return iter(self._snapshot_fossils.values())
def __iter__(self) -> Iterator["SnapshotCollection"]:
return iter(self._snapshot_collections.values())

def __contains__(self, key: str) -> bool:
return key in self._snapshot_fossils
return key in self._snapshot_collections


@dataclass
Expand Down
27 changes: 14 additions & 13 deletions src/syrupy/extensions/amber/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
Set,
)

from syrupy.data import SnapshotFossil
from syrupy.data import SnapshotCollection
from syrupy.extensions.base import AbstractSyrupyExtension

from .serializer import DataSerializer
Expand All @@ -21,6 +21,8 @@ class AmberSnapshotExtension(AbstractSyrupyExtension):
An amber snapshot file stores data in the following format:
"""

_file_extension = "ambr"

def serialize(self, data: "SerializableData", **kwargs: Any) -> str:
"""
Returns the serialized form of 'data' to be compared
Expand All @@ -31,27 +33,23 @@ def serialize(self, data: "SerializableData", **kwargs: Any) -> str:
def delete_snapshots(
self, snapshot_location: str, snapshot_names: Set[str]
) -> None:
snapshot_fossil_to_update = DataSerializer.read_file(snapshot_location)
snapshot_collection_to_update = DataSerializer.read_file(snapshot_location)
for snapshot_name in snapshot_names:
snapshot_fossil_to_update.remove(snapshot_name)
snapshot_collection_to_update.remove(snapshot_name)

if snapshot_fossil_to_update.has_snapshots:
DataSerializer.write_file(snapshot_fossil_to_update)
if snapshot_collection_to_update.has_snapshots:
DataSerializer.write_file(snapshot_collection_to_update)
else:
Path(snapshot_location).unlink()

@property
def _file_extension(self) -> str:
return "ambr"

def _read_snapshot_fossil(self, snapshot_location: str) -> "SnapshotFossil":
def _read_snapshot_collection(self, snapshot_location: str) -> "SnapshotCollection":
return DataSerializer.read_file(snapshot_location)

@staticmethod
@lru_cache()
def __cacheable_read_snapshot(
snapshot_location: str, cache_key: str
) -> "SnapshotFossil":
) -> "SnapshotCollection":
return DataSerializer.read_file(snapshot_location)

def _read_snapshot_data_from_location(
Expand All @@ -63,8 +61,11 @@ def _read_snapshot_data_from_location(
snapshot = snapshots.get(snapshot_name)
return snapshot.data if snapshot else None

def _write_snapshot_fossil(self, *, snapshot_fossil: "SnapshotFossil") -> None:
DataSerializer.write_file(snapshot_fossil, merge=True)
@classmethod
def _write_snapshot_collection(
cls, *, snapshot_collection: "SnapshotCollection"
) -> None:
DataSerializer.write_file(snapshot_collection, merge=True)


__all__ = ["AmberSnapshotExtension", "DataSerializer"]
22 changes: 12 additions & 10 deletions src/syrupy/extensions/amber/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from syrupy.data import (
Snapshot,
SnapshotFossil,
SnapshotCollection,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -70,18 +70,20 @@ class DataSerializer:
_marker_crn: str = "\r\n"

@classmethod
def write_file(cls, snapshot_fossil: "SnapshotFossil", merge: bool = False) -> None:
def write_file(
cls, snapshot_collection: "SnapshotCollection", merge: bool = False
) -> None:
"""
Writes the snapshot data into the snapshot file that can be read later.
"""
filepath = snapshot_fossil.location
filepath = snapshot_collection.location
if merge:
base_snapshot = cls.read_file(filepath)
base_snapshot.merge(snapshot_fossil)
snapshot_fossil = base_snapshot
base_snapshot.merge(snapshot_collection)
snapshot_collection = base_snapshot

with open(filepath, "w", encoding=TEXT_ENCODING, newline=None) as f:
for snapshot in sorted(snapshot_fossil, key=lambda s: s.name):
for snapshot in sorted(snapshot_collection, key=lambda s: s.name):
snapshot_data = str(snapshot.data)
if snapshot_data is not None:
f.write(f"{cls._marker_name} {snapshot.name}\n")
Expand All @@ -90,15 +92,15 @@ def write_file(cls, snapshot_fossil: "SnapshotFossil", merge: bool = False) -> N
f.write(f"\n{cls._marker_divider}\n")

@classmethod
def read_file(cls, filepath: str) -> "SnapshotFossil":
def read_file(cls, filepath: str) -> "SnapshotCollection":
"""
Read the raw snapshot data (str) from the snapshot file into a dict
of snapshot name to raw data. This does not attempt any deserialization
of the snapshot data.
"""
name_marker_len = len(cls._marker_name)
indent_len = len(cls._indent)
snapshot_fossil = SnapshotFossil(location=filepath)
snapshot_collection = SnapshotCollection(location=filepath)
try:
with open(filepath, "r", encoding=TEXT_ENCODING, newline=None) as f:
test_name = None
Expand All @@ -112,7 +114,7 @@ def read_file(cls, filepath: str) -> "SnapshotFossil":
if line.startswith(cls._indent):
snapshot_data += line[indent_len:]
elif line.startswith(cls._marker_divider) and snapshot_data:
snapshot_fossil.add(
snapshot_collection.add(
Snapshot(
name=test_name,
data=snapshot_data.rstrip(os.linesep),
Expand All @@ -121,7 +123,7 @@ def read_file(cls, filepath: str) -> "SnapshotFossil":
except FileNotFoundError:
pass

return snapshot_fossil
return snapshot_collection

@classmethod
def serialize(
Expand Down
Loading

0 comments on commit 1038c30

Please sign in to comment.