Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for custom snapshot names, close #555 #563

Merged
merged 13 commits into from
Nov 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,18 @@ Syrupy comes with a few built-in preset configurations for you to choose from. Y
- **`PNGSnapshotExtension`**: An extension of single file, this should be used to produce `.png` files from a byte string.
- **`SVGSnapshotExtension`**: Another extension of single file. This produces `.svg` files from an svg string.

#### `name`

By default, if you make multiple snapshot assertions within a single test case, an auto-increment identifier will be used to index the snapshots. You can override this behaviour by specifying a custom snapshot name to use in place of the auto-increment number.

```py
def test_case(snapshot):
assert "actual" == snapshot(name="case_a")
assert "other" == snapshot(name="case_b")
```

> _Warning_: If you use a custom name, you must make sure the name is not re-used within a test case.

### Advanced Usage

By overriding the provided [`AbstractSnapshotExtension`](https://github.com/tophat/syrupy/tree/master/src/syrupy/extensions/base.py) you can implement varied custom behaviours.
Expand Down
24 changes: 18 additions & 6 deletions src/syrupy/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
List,
Optional,
Type,
Union,
)

import attr
Expand Down Expand Up @@ -54,6 +55,7 @@ class SnapshotAssertion:
_exclude: Optional["PropertyFilter"] = attr.ib(
init=False, default=None, kw_only=True
)
_custom_index: Optional[str] = attr.ib(init=False, default=None, kw_only=True)
_extension: Optional["AbstractSyrupyExtension"] = attr.ib(
init=False, default=None, kw_only=True
)
Expand Down Expand Up @@ -90,6 +92,12 @@ def num_executions(self) -> int:
def executions(self) -> Dict[int, AssertionResult]:
return self._execution_results

@property
def index(self) -> Union[str, int]:
if self._custom_index:
return self._custom_index
return self.num_executions

def use_extension(
self, extension_class: Optional[Type["AbstractSyrupyExtension"]] = None
) -> "SnapshotAssertion":
Expand Down Expand Up @@ -149,6 +157,7 @@ def __call__(
exclude: Optional["PropertyFilter"] = None,
extension_class: Optional[Type["AbstractSyrupyExtension"]] = None,
matcher: Optional["PropertyMatcher"] = None,
name: Optional[str] = None,
) -> "SnapshotAssertion":
"""
Modifies assertion instance options
Expand All @@ -159,6 +168,8 @@ def __call__(
self.__with_prop("_extension", self.__init_extension(extension_class))
if matcher:
self.__with_prop("_matcher", matcher)
if name:
self.__with_prop("_custom_index", name)
return self

def __dir__(self) -> List[str]:
Expand All @@ -168,21 +179,22 @@ def __eq__(self, other: "SerializableData") -> bool:
return self._assert(other)

def _assert(self, data: "SerializableData") -> bool:
snapshot_location = self.extension.get_location(index=self.num_executions)
snapshot_name = self.extension.get_snapshot_name(index=self.num_executions)
snapshot_location = self.extension.get_location(index=self.index)
snapshot_name = self.extension.get_snapshot_name(index=self.index)
snapshot_data: Optional["SerializedData"] = None
serialized_data: Optional["SerializedData"] = None
matches = False
assertion_success = False
assertion_exception = None
try:
snapshot_data = self._recall_data(index=self.num_executions)
snapshot_data = self._recall_data()
serialized_data = self._serialize(data)
matches = snapshot_data is not None and serialized_data == snapshot_data
assertion_success = matches
if not matches and self._update_snapshots:
self.extension.write_snapshot(
data=serialized_data, index=self.num_executions
data=serialized_data,
index=self.index,
)
assertion_success = True
return assertion_success
Expand Down Expand Up @@ -212,8 +224,8 @@ def _post_assert(self) -> None:
while self._post_assert_actions:
self._post_assert_actions.pop()()

def _recall_data(self, index: int) -> Optional["SerializableData"]:
def _recall_data(self) -> Optional["SerializableData"]:
try:
return self.extension.read_snapshot(index=index)
return self.extension.read_snapshot(index=self.index)
except SnapshotDoesNotExist:
return None
29 changes: 18 additions & 11 deletions src/syrupy/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
List,
Optional,
Set,
Union,
)

from syrupy.constants import (
Expand Down Expand Up @@ -73,12 +74,16 @@ class SnapshotFossilizer(ABC):
def test_location(self) -> "PyTestLocation":
raise NotImplementedError

def get_snapshot_name(self, *, index: int = 0) -> str:
def get_snapshot_name(self, *, index: Union[str, int] = 0) -> str:
"""Get the snapshot name for the assertion index in a test location"""
index_suffix = f".{index}" if index > 0 else ""
index_suffix = ""
if isinstance(index, (str,)):
index_suffix = f"[{index}]"
elif index:
index_suffix = f".{index}"
return f"{self.test_location.snapshot_name}{index_suffix}"

def get_location(self, *, index: int) -> str:
def get_location(self, *, index: Union[str, int]) -> str:
"""Returns full location where snapshot data is stored."""
basename = self._get_file_basename(index=index)
fileext = f".{self._file_extension}" if self._file_extension else ""
Expand All @@ -105,7 +110,7 @@ def discover_snapshots(self) -> "SnapshotFossils":

return discovered

def read_snapshot(self, *, index: int) -> "SerializedData":
def read_snapshot(self, *, index: Union[str, int]) -> "SerializedData":
"""
Utility method for reading the contents of a snapshot assertion.
Will call `_pre_read`, then perform `read` and finally `post_read`,
Expand All @@ -127,7 +132,7 @@ def read_snapshot(self, *, index: int) -> "SerializedData":
finally:
self._post_read(index=index)

def write_snapshot(self, *, data: "SerializedData", index: int) -> None:
def write_snapshot(self, *, data: "SerializedData", index: Union[str, int]) -> None:
"""
Utility method for writing the contents of a snapshot assertion.
Will call `_pre_write`, then perform `write` and finally `_post_write`.
Expand Down Expand Up @@ -173,16 +178,18 @@ def delete_snapshots(
"""
raise NotImplementedError

def _pre_read(self, *, index: int = 0) -> None:
def _pre_read(self, *, index: Union[str, int] = 0) -> None:
pass

def _post_read(self, *, index: int = 0) -> None:
def _post_read(self, *, index: Union[str, int] = 0) -> None:
pass

def _pre_write(self, *, data: "SerializedData", index: int = 0) -> None:
def _pre_write(self, *, data: "SerializedData", index: Union[str, int] = 0) -> None:
self.__ensure_snapshot_dir(index=index)

def _post_write(self, *, data: "SerializedData", index: int = 0) -> None:
def _post_write(
self, *, data: "SerializedData", index: Union[str, int] = 0
) -> None:
pass

@abstractmethod
Expand Down Expand Up @@ -218,11 +225,11 @@ def _dirname(self) -> str:
def _file_extension(self) -> str:
raise NotImplementedError

def _get_file_basename(self, *, index: int) -> str:
def _get_file_basename(self, *, index: Union[str, int]) -> str:
"""Returns file basename without extension. Used to create full filepath."""
return self.test_location.filename

def __ensure_snapshot_dir(self, *, index: int) -> None:
def __ensure_snapshot_dir(self, *, index: Union[str, int]) -> None:
"""
Ensures the folder path for the snapshot file exists.
"""
Expand Down
5 changes: 3 additions & 2 deletions src/syrupy/extensions/single_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
TYPE_CHECKING,
Optional,
Set,
Union,
)
from unicodedata import category

Expand Down Expand Up @@ -33,7 +34,7 @@ def serialize(
) -> "SerializedData":
return bytes(data)

def get_snapshot_name(self, *, index: int = 0) -> str:
def get_snapshot_name(self, *, index: Union[str, int] = 0) -> str:
return self.__clean_filename(
super(SingleFileSnapshotExtension, self).get_snapshot_name(index=index)
)
Expand All @@ -47,7 +48,7 @@ def delete_snapshots(
def _file_extension(self) -> str:
return "raw"

def _get_file_basename(self, *, index: int) -> str:
def _get_file_basename(self, *, index: Union[str, int]) -> str:
return self.get_snapshot_name(index=index)

@property
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# name: test_snapshot_custom_snapshot_name_suffix[test_is_amazing]
'Syrupy is amazing!'
---
# name: test_snapshot_custom_snapshot_name_suffix[test_is_awesome]
'Syrupy is awesome!'
---
25 changes: 25 additions & 0 deletions tests/examples/test_custom_image_name_suffix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import base64

import pytest

from syrupy.extensions.image import PNGImageSnapshotExtension


@pytest.fixture
def snapshot(snapshot):
return snapshot.use_extension(PNGImageSnapshotExtension)


def test_png_image_with_custom_name_suffix(snapshot):
reddish_square = base64.b64decode(
b"iVBORw0KGgoAAAANSUhEUgAAAAQAAAAECAIAAAAmkwkpAAAAIUlEQVQIHTXB"
b"MQEAAAABQUYtvpD+dUzu3KBzg84NOjfoBjmmAd3WpSsrAAAAAElFTkSuQmCC"
)

blueish_square = base64.b64decode(
b"iVBORw0KGgoAAAANSUhEUgAAAAQAAAAECAIAAAAmkwkpAAAAIUlEQVQIHTXB"
b"MQEAAAABQUYtvpD+dUzuTKozqc6kOpPqBjg+Ad2g/BLMAAAAAElFTkSuQmCC"
)

assert blueish_square == snapshot(name="blueish")
assert reddish_square == snapshot(name="reddish")
4 changes: 3 additions & 1 deletion tests/examples/test_custom_snapshot_name.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""
Example: Custom Snapshot Name
"""
from typing import Union

import pytest

from syrupy.extensions.amber import AmberSnapshotExtension


class CanadianNameExtension(AmberSnapshotExtension):
def get_snapshot_name(self, *, index: int = 0) -> str:
def get_snapshot_name(self, *, index: Union[str, int]) -> str:
original_name = super(CanadianNameExtension, self).get_snapshot_name(
index=index
)
Expand Down
3 changes: 3 additions & 0 deletions tests/examples/test_custom_snapshot_name_suffix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
def test_snapshot_custom_snapshot_name_suffix(snapshot):
assert "Syrupy is amazing!" == snapshot(name="test_is_amazing")
assert "Syrupy is awesome!" == snapshot(name="test_is_awesome")
63 changes: 63 additions & 0 deletions tests/integration/test_snapshot_option_name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import pytest


@pytest.fixture
def testcases():
return {
"base": (
"""
def test_a(snapshot):
assert snapshot(name="xyz") == "case 1"
assert snapshot(name="zyx") == "case 2"
"""
),
"modified": (
"""
def test_a(snapshot):
assert snapshot(name="xyz") == "case 1"
assert snapshot(name="zyx") == "case ??"
"""
),
}


@pytest.fixture
def run_testcases(testdir, testcases):
testdir.makepyfile(test_1=testcases["base"])
result = testdir.runpytest(
"-v",
"--snapshot-update",
)
result.stdout.re_match_lines((r"2 snapshots generated\."))
return testdir, testcases


def test_run_all(run_testcases):
testdir, testcases = run_testcases
result = testdir.runpytest(
"-v",
)
result.stdout.re_match_lines("2 snapshots passed")
assert result.ret == 0


def test_failure(run_testcases):
testdir, testcases = run_testcases
testdir.makepyfile(test_1=testcases["modified"])
result = testdir.runpytest(
"-v",
)
result.stdout.re_match_lines("1 snapshot failed. 1 snapshot passed.")
assert result.ret == 1


def test_update(run_testcases):
testdir, testcases = run_testcases
testdir.makepyfile(test_1=testcases["modified"])
result = testdir.runpytest(
"-v",
"--snapshot-update",
)
assert "Can not relate snapshot name" not in str(result.stdout)
result.stdout.re_match_lines("1 snapshot passed. 1 snapshot updated.")
assert result.ret == 0
4 changes: 2 additions & 2 deletions tests/integration/test_snapshot_use_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ def _file_extension(self):
def serialize(self, data, **kwargs):
return str(data)

def get_snapshot_name(self, *, index = 0):
def get_snapshot_name(self, *, index):
testname = self._test_location.testname[::-1]
return f"{testname}.{index}"

def _get_file_basename(self, *, index = 0):
def _get_file_basename(self, *, index):
return self.test_location.filename[::-1]

@pytest.fixture
Expand Down