Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add include option to snapshots, similar to exclude #797

Merged
merged 1 commit into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,20 +99,32 @@ If you want to limit what properties are serialized at a class type level you co

```py
def limit_foo_attrs(prop, path):
allowed_foo_attrs = {"only", "serialize", "these", "attrs"}
return isinstance(path[-1][1], Foo) and prop in allowed_foo_attrs
allowed_foo_attrs = {"do", "not", "serialize", "these", "attrs"}
return isinstance(path[-1][1], Foo) and prop in allowed_foo_attrs

def test_bar(snapshot):
actual = Foo(...)
assert actual == snapshot(exclude=limit_foo_attrs)
```

**B**. Or override the `__dir__` implementation to control the attribute list.
**B**. Provide a filter function to the snapshot [include](#include) configuration option.

```py
def limit_foo_attrs(prop, path):
allowed_foo_attrs = {"only", "serialize", "these", "attrs"}
return isinstance(path[-1][1], Foo) and prop in allowed_foo_attrs

def test_bar(snapshot):
actual = Foo(...)
assert actual == snapshot(include=limit_foo_attrs)
```

**C**. Or override the `__dir__` implementation to control the attribute list.

```py
class Foo:
def __dir__(self):
return ["only", "serialize", "these", "attrs"]
def __dir__(self):
return ["only", "serialize", "these", "attrs"]

def test_bar(snapshot):
actual = Foo(...)
Expand Down Expand Up @@ -211,7 +223,7 @@ Only runs replacement for objects at a matching path where the value of the mapp
This allows you to filter out object properties from the serialized snapshot.

The exclude parameter takes a filter function that accepts two keyword arguments.
It should return `true` or `false` if the property should be excluded or included respectively.
It should return `true` if the property should be excluded, or `false` if the property should be included.

| Argument | Description |
| -------- | --------------------------------------------------------------------------------------------------------------------------------------------- |
Expand Down Expand Up @@ -278,6 +290,29 @@ def test_bar(snapshot):
# ---
```

#### `include`

This allows you filter an object's properties to a subset using a predicate. This is the opposite of [exclude](#exclude). All the same property filters supporterd by [exclude](#exclude) are supported for `include`.

The include parameter takes a filter function that accepts two keyword arguments.
It should return `true` if the property should be include, or `false` if the property should not be included.

| Argument | Description |
| -------- | --------------------------------------------------------------------------------------------------------------------------------------------- |
| `prop` | Current property on the object, could be any hashable value that can be used to retrieve a value e.g. `1`, `"prop_str"`, `SomeHashableObject` |
| `path` | Ordered path traversed to the current value e.g. `(("a", dict), ("b", dict))` from `{ "a": { "b": { "c": 1 } } }`}

Note that `include` has some caveats which make it a bit more difficult to use than `exclude`. Both `include` and `exclude` are evaluated for each key of an object before traversing down nested paths. This means if you want to include a nested path, you must include all parents of the nested path, otherwise the nested child will never be reached to be evaluated against the include predicate. For example:

```py
obj = {
"nested": { "key": True }
}
assert obj == snapshot(include=paths("nested", "nested.key"))
```

The extra "nested" is required, otherwise the nested dictionary will never be searched -- it'd get pruned too early.

#### `extension_class`

This is a way to modify how the snapshot matches and serializes your data in a single assertion.
Expand Down
9 changes: 8 additions & 1 deletion src/syrupy/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ class SnapshotAssertion:
init=False,
default=None,
)
_include: Optional["PropertyFilter"] = field(
init=False,
default=None,
)
_custom_index: Optional[str] = field(
init=False,
default=None,
Expand Down Expand Up @@ -180,7 +184,7 @@ def assert_match(self, data: "SerializableData") -> None:

def _serialize(self, data: "SerializableData") -> "SerializedData":
return self.extension.serialize(
data, exclude=self._exclude, matcher=self.__matcher
data, exclude=self._exclude, include=self._include, matcher=self.__matcher
)

def get_assert_diff(self) -> List[str]:
Expand Down Expand Up @@ -233,6 +237,7 @@ def __call__(
*,
diff: Optional["SnapshotIndex"] = None,
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
extension_class: Optional[Type["AbstractSyrupyExtension"]] = None,
matcher: Optional["PropertyMatcher"] = None,
name: Optional["SnapshotIndex"] = None,
Expand All @@ -242,6 +247,8 @@ def __call__(
"""
if exclude:
self.__with_prop("_exclude", exclude)
if include:
self.__with_prop("_include", include)
if extension_class:
self.__with_prop("_extension", self.__init_extension(extension_class))
if matcher:
Expand Down
17 changes: 14 additions & 3 deletions src/syrupy/extensions/amber/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def serialize(
data: "SerializableData",
*,
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
) -> str:
"""
Expand All @@ -211,7 +212,9 @@ def serialize(
same new line control characters. Example snapshots generated on windows os
should not break when running the tests on a unix based system and vice versa.
"""
serialized = cls._serialize(data, exclude=exclude, matcher=matcher)
serialized = cls._serialize(
data, exclude=exclude, include=include, matcher=matcher
)
return serialized.replace("\r\n", "\n").replace("\r", "\n")

@classmethod
Expand All @@ -221,6 +224,7 @@ def _serialize(
*,
depth: int = 0,
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
path: "PropertyPath" = (),
visited: Optional[Set[Any]] = None,
Expand All @@ -235,6 +239,7 @@ def _serialize(
"data": data,
"depth": depth,
"exclude": exclude,
"include": include,
"matcher": matcher,
"path": path,
"visited": {*visited, data_id},
Expand Down Expand Up @@ -400,6 +405,7 @@ def serialize_custom_iterable(
close_paren: Optional[str] = None,
depth: int = 0,
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
path: "PropertyPath" = (),
separator: Optional[str] = None,
serialize_key: bool = False,
Expand All @@ -414,7 +420,8 @@ def serialize_custom_iterable(
key_values = (
(key, get_value(data, key))
for key in keys
if not exclude or not exclude(prop=key, path=path)
if (not exclude or not exclude(prop=key, path=path))
and (not include or include(prop=key, path=path))
)
entries = (
entry
Expand All @@ -433,7 +440,11 @@ def key_str(key: "PropertyName") -> str:

def value_str(key: "PropertyName", value: "SerializableData") -> str:
serialized = cls._serialize(
data=value, exclude=exclude, path=(*path, (key, type(value))), **kwargs
data=value,
exclude=exclude,
include=include,
path=(*path, (key, type(value))),
**kwargs,
)
return serialized if separator is None else serialized.lstrip(cls._indent)

Expand Down
1 change: 1 addition & 0 deletions src/syrupy/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def serialize(
data: "SerializableData",
*,
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
) -> "SerializedData":
"""
Expand Down
14 changes: 13 additions & 1 deletion src/syrupy/extensions/json/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def _filter(
depth: int = 0,
path: "PropertyPath",
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
visited: Optional[Set[Any]] = None,
) -> "SerializableData":
Expand All @@ -80,13 +81,16 @@ def _filter(
value = data[key]
if exclude and exclude(prop=key, path=path):
continue
if include and not include(prop=key, path=path):
continue
if not isinstance(key, (str,)):
continue
filtered_dct[key] = cls._filter(
data=value,
depth=depth + 1,
path=(*path, (key, type(value))),
exclude=exclude,
include=include,
matcher=matcher,
visited={*visited, data_id},
)
Expand All @@ -101,6 +105,7 @@ def _filter(
depth=depth + 1,
path=(*path, (key, type(value))),
exclude=exclude,
include=include,
matcher=matcher,
visited={*visited, data_id},
)
Expand All @@ -118,6 +123,7 @@ def _filter(
depth=depth + 1,
path=(*path, (key, type(value))),
exclude=exclude,
include=include,
matcher=matcher,
visited={*visited, data_id},
)
Expand All @@ -137,9 +143,15 @@ def serialize(
data: "SerializableData",
*,
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
) -> "SerializedData":
data = self._filter(
data=data, depth=0, path=(), exclude=exclude, matcher=matcher
data=data,
depth=0,
path=(),
exclude=exclude,
include=include,
matcher=matcher,
)
return json.dumps(data, indent=2, ensure_ascii=False, sort_keys=False) + "\n"
1 change: 1 addition & 0 deletions src/syrupy/extensions/single_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def serialize(
data: "SerializableData",
*,
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
) -> "SerializedData":
return self.get_supported_dataclass()(data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,18 @@
}),
})
# ---
# name: test_only_includes_expected_props
dict({
'date': 'utc',
0: 'some value',
})
# ---
# name: test_only_includes_expected_props.1
dict({
'date': 'utc',
'nested': dict({
'id': 4,
}),
0: 'some value',
})
# ---
12 changes: 12 additions & 0 deletions tests/syrupy/extensions/amber/test_amber_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,18 @@ def test_filters_expected_props(snapshot):
assert actual == snapshot(exclude=props("0", "date", "id"))


def test_only_includes_expected_props(snapshot):
actual = {
0: "some value",
"date": "utc",
"nested": {"id": 4, "other": "value"},
"list": [1, 2],
}
# Note that "id" won't get included because "nested" (its parent) is not included.
assert actual == snapshot(include=props("0", "date", "id"))
assert actual == snapshot(include=paths("0", "date", "nested", "nested.id"))


@pytest.mark.parametrize(
"predicate", [paths("exclude_me", "nested.exclude_me"), props("exclude_me")]
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"foo": "__SHOULD_BE_REMOVED_FROM_JSON__",
"id": 123456789
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"foo": "__SHOULD_BE_REMOVED_FROM_JSON__",
"id": 123456789
}
13 changes: 13 additions & 0 deletions tests/syrupy/extensions/json/test_json_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,19 @@ def test_exclude_simple(snapshot_json):
assert snapshot_json(exclude=paths("id", "foo")) == content


def test_include_simple(snapshot_json):
content = {
"id": 123456789,
"foo": "__SHOULD_BE_REMOVED_FROM_JSON__",
"I'm": "still alive",
"nested": {
"foo": "is still alive",
},
}
assert snapshot_json(include=props("id", "foo")) == content
assert snapshot_json(include=paths("id", "foo")) == content


def test_exclude_nested(snapshot_json):
content = {
"a": "b",
Expand Down