Skip to content

Commit

Permalink
fix: defer snapshot writes until end of session (#606)
Browse files Browse the repository at this point in the history
  • Loading branch information
noahnu authored and Noah Negin-Ulster committed Dec 30, 2022
1 parent 4a9695d commit 3748df6
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 42 deletions.
7 changes: 5 additions & 2 deletions src/syrupy/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,8 @@ def _assert(self, data: "SerializableData") -> bool:
)
assertion_success = matches
if not matches and self.update_snapshots:
self.extension.write_snapshot(
self.session.queue_snapshot_write(
extension=self.extension,
data=serialized_data,
index=self.index,
)
Expand Down Expand Up @@ -297,6 +298,8 @@ def _post_assert(self) -> None:

def _recall_data(self, index: "SnapshotIndex") -> Optional["SerializableData"]:
try:
return self.extension.read_snapshot(index=index)
return self.extension.read_snapshot(
index=index, session_id=str(id(self.session))
)
except SnapshotDoesNotExist:
return None
18 changes: 13 additions & 5 deletions src/syrupy/extensions/amber/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import lru_cache
from pathlib import Path
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -46,16 +47,23 @@ def _file_extension(self) -> str:
def _read_snapshot_fossil(self, snapshot_location: str) -> "SnapshotFossil":
return DataSerializer.read_file(snapshot_location)

@lru_cache()
def __cacheable_read_snapshot(
self, snapshot_location: str, cache_key: str
) -> "SnapshotFossil":
return DataSerializer.read_file(snapshot_location)

def _read_snapshot_data_from_location(
self, snapshot_location: str, snapshot_name: str
self, snapshot_location: str, snapshot_name: str, session_id: str
) -> Optional["SerializableData"]:
snapshot = self._read_snapshot_fossil(snapshot_location).get(snapshot_name)
snapshots = self.__cacheable_read_snapshot(
snapshot_location=snapshot_location, cache_key=session_id
)
snapshot = snapshots.get(snapshot_name)
return snapshot.data if snapshot else None

def _write_snapshot_fossil(self, *, snapshot_fossil: "SnapshotFossil") -> None:
snapshot_fossil_to_update = DataSerializer.read_file(snapshot_fossil.location)
snapshot_fossil_to_update.merge(snapshot_fossil)
DataSerializer.write_file(snapshot_fossil_to_update)
DataSerializer.write_file(snapshot_fossil, merge=True)


__all__ = ["AmberSnapshotExtension", "DataSerializer"]
11 changes: 7 additions & 4 deletions src/syrupy/extensions/amber/serializer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import functools
import os
from types import (
GeneratorType,
Expand Down Expand Up @@ -71,11 +70,16 @@ class DataSerializer:
_marker_crn: str = "\r\n"

@classmethod
def write_file(cls, snapshot_fossil: "SnapshotFossil") -> None:
def write_file(cls, snapshot_fossil: "SnapshotFossil", merge: bool = False) -> None:
"""
Writes the snapshot data into the snapshot file that be read later.
Writes the snapshot data into the snapshot file that can be read later.
"""
filepath = snapshot_fossil.location
if merge:
base_snapshot = cls.read_file(filepath)
base_snapshot.merge(snapshot_fossil)
snapshot_fossil = base_snapshot

with open(filepath, "w", encoding=TEXT_ENCODING, newline=None) as f:
for snapshot in sorted(snapshot_fossil, key=lambda s: s.name):
snapshot_data = str(snapshot.data)
Expand All @@ -86,7 +90,6 @@ def write_file(cls, snapshot_fossil: "SnapshotFossil") -> None:
f.write(f"\n{cls._marker_divider}\n")

@classmethod
@functools.lru_cache()
def read_file(cls, filepath: str) -> "SnapshotFossil":
"""
Read the raw snapshot data (str) from the snapshot file into a dict
Expand Down
100 changes: 70 additions & 30 deletions src/syrupy/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,21 @@
ABC,
abstractmethod,
)
from collections import defaultdict
from difflib import ndiff
from gettext import gettext
from itertools import zip_longest
from pathlib import Path
from typing import (
TYPE_CHECKING,
Callable,
DefaultDict,
Dict,
Iterator,
List,
Optional,
Set,
Tuple,
)

from syrupy.constants import (
Expand Down Expand Up @@ -115,7 +118,9 @@ def discover_snapshots(self) -> "SnapshotFossils":

return discovered

def read_snapshot(self, *, index: "SnapshotIndex") -> "SerializedData":
def read_snapshot(
self, *, index: "SnapshotIndex", session_id: str
) -> "SerializedData":
"""
Utility method for reading the contents of a snapshot assertion.
Will call `_pre_read`, then perform `read` and finally `post_read`,
Expand All @@ -129,7 +134,9 @@ def read_snapshot(self, *, index: "SnapshotIndex") -> "SerializedData":
snapshot_location = self.get_location(index=index)
snapshot_name = self.get_snapshot_name(index=index)
snapshot_data = self._read_snapshot_data_from_location(
snapshot_location=snapshot_location, snapshot_name=snapshot_name
snapshot_location=snapshot_location,
snapshot_name=snapshot_name,
session_id=session_id,
)
if snapshot_data is None:
raise SnapshotDoesNotExist()
Expand All @@ -145,33 +152,66 @@ def write_snapshot(self, *, data: "SerializedData", index: "SnapshotIndex") -> N
This method is _final_, do not override. You can override
`_write_snapshot_fossil` in a subclass to change behaviour.
"""
self._pre_write(data=data, index=index)
snapshot_location = self.get_location(index=index)
if not self.test_location.matches_snapshot_location(snapshot_location):
warning_msg = gettext(
"{line_end}Can not relate snapshot location '{}' to the test location."
"{line_end}Consider adding '{}' to the generated location."
).format(
snapshot_location,
self.test_location.filename,
line_end="\n",
)
warnings.warn(warning_msg)
snapshot_name = self.get_snapshot_name(index=index)
if not self.test_location.matches_snapshot_name(snapshot_name):
warning_msg = gettext(
"{line_end}Can not relate snapshot name '{}' to the test location."
"{line_end}Consider adding '{}' to the generated name."
).format(
snapshot_name,
self.test_location.testname,
line_end="\n",
)
warnings.warn(warning_msg)
snapshot_fossil = SnapshotFossil(location=snapshot_location)
snapshot_fossil.add(Snapshot(name=snapshot_name, data=data))
self._write_snapshot_fossil(snapshot_fossil=snapshot_fossil)
self._post_write(data=data, index=index)
self.write_snapshot_batch(snapshots=[(data, index)])

def write_snapshot_batch(
self, *, snapshots: List[Tuple["SerializedData", "SnapshotIndex"]]
) -> None:
"""
Utility method for writing the contents of multiple snapshot assertions.
Will call `_pre_write` per snapshot, then perform `write` per snapshot
and finally `_post_write`.
This method is _final_, do not override. You can override
`_write_snapshot_fossil` in a subclass to change behaviour.
"""
# First we group by location since it'll let us batch by file on disk.
# Not as useful for single file snapshots, but useful for the standard
# Amber extension.
locations: DefaultDict[str, List["Snapshot"]] = defaultdict(list)
for data, index in snapshots:
location = self.get_location(index=index)
snapshot_name = self.get_snapshot_name(index=index)
locations[location].append(Snapshot(name=snapshot_name, data=data))

# Is there a better place to do the pre-writes?
# Or can we remove the pre-write concept altogether?
self._pre_write(data=data, index=index)

for location, location_snapshots in locations.items():
snapshot_fossil = SnapshotFossil(location=location)

if not self.test_location.matches_snapshot_location(location):
warning_msg = gettext(
"{line_end}Can not relate snapshot location '{}' "
"to the test location.{line_end}"
"Consider adding '{}' to the generated location."
).format(
location,
self.test_location.filename,
line_end="\n",
)
warnings.warn(warning_msg)

for snapshot in location_snapshots:
snapshot_fossil.add(snapshot)

if not self.test_location.matches_snapshot_name(snapshot.name):
warning_msg = gettext(
"{line_end}Can not relate snapshot name '{}' "
"to the test location.{line_end}"
"Consider adding '{}' to the generated name."
).format(
snapshot.name,
self.test_location.testname,
line_end="\n",
)
warnings.warn(warning_msg)

self._write_snapshot_fossil(snapshot_fossil=snapshot_fossil)

for data, index in snapshots:
self._post_write(data=data, index=index)

@abstractmethod
def delete_snapshots(
Expand Down Expand Up @@ -206,7 +246,7 @@ def _read_snapshot_fossil(self, *, snapshot_location: str) -> "SnapshotFossil":

@abstractmethod
def _read_snapshot_data_from_location(
self, *, snapshot_location: str, snapshot_name: str
self, *, snapshot_location: str, snapshot_name: str, session_id: str
) -> Optional["SerializedData"]:
"""
Get only the snapshot data from location for assertion
Expand Down
2 changes: 1 addition & 1 deletion src/syrupy/extensions/single_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _read_snapshot_fossil(self, *, snapshot_location: str) -> "SnapshotFossil":
return snapshot_fossil

def _read_snapshot_data_from_location(
self, *, snapshot_location: str, snapshot_name: str
self, *, snapshot_location: str, snapshot_name: str, session_id: str
) -> Optional["SerializableData"]:
try:
with open(
Expand Down
26 changes: 26 additions & 0 deletions src/syrupy/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
List,
Optional,
Set,
Tuple,
)

import pytest
Expand All @@ -24,6 +25,10 @@
is_xdist_controller,
is_xdist_worker,
)
from .types import (
SerializedData,
SnapshotIndex,
)

if TYPE_CHECKING:
from .assertion import SnapshotAssertion
Expand All @@ -47,6 +52,26 @@ class SnapshotSession:
default_factory=lambda: defaultdict(set)
)

_queued_snapshot_writes: Dict[
"AbstractSyrupyExtension", List[Tuple["SerializedData", "SnapshotIndex"]]
] = field(default_factory=dict)

def queue_snapshot_write(
self,
extension: "AbstractSyrupyExtension",
data: "SerializedData",
index: "SnapshotIndex",
) -> None:
queue = self._queued_snapshot_writes.get(extension, [])
queue.append((data, index))
self._queued_snapshot_writes[extension] = queue

def flush_snapshot_write_queue(self) -> None:
for extension, queued_write in self._queued_snapshot_writes.items():
if queued_write:
extension.write_snapshot_batch(snapshots=queued_write)
self._queued_snapshot_writes = {}

@property
def update_snapshots(self) -> bool:
return bool(self.pytest_session.config.option.update_snapshots)
Expand Down Expand Up @@ -76,6 +101,7 @@ def ran_item(self, nodeid: str) -> None:

def finish(self) -> int:
exitstatus = 0
self.flush_snapshot_write_queue()
self.report = SnapshotReport(
base_dir=self.pytest_session.config.rootdir,
collected_items=self._collected_items,
Expand Down

0 comments on commit 3748df6

Please sign in to comment.