Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support static type checking #59

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
4 changes: 0 additions & 4 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@ jobs:
os: macos-latest
# macos builds sometimes get stuck starting "python -m tox".
experimental: true
- python-version: 3.5
# Latest os version that supports Python 3.5
os: ubuntu-20.04
experimental: false
- python-version: 3.6
# Latest os version that supports Python 3.6
os: ubuntu-20.04
Expand Down
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
include LICENSE
include README.rst
include pytest_snapshot/py.typed

recursive-exclude * __pycache__
recursive-exclude * *.py[co]
33 changes: 20 additions & 13 deletions pytest_snapshot/_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import os
import re
from pathlib import Path
from typing import Dict, List, Tuple, TypeVar, Union, cast

import pytest

SIMPLE_VERSION_REGEX = re.compile(r'([0-9]+)\.([0-9]+)\.([0-9]+)')
ILLEGAL_FILENAME_CHARS = r'\/:*?"<>|'

_K = TypeVar("_K")
_V = TypeVar("_V")
_RecursiveDict = Dict[_K, Union["_RecursiveDict", _V]]


def shorten_path(path: Path) -> Path:
"""
Expand Down Expand Up @@ -44,7 +49,7 @@ def might_be_valid_filename(s: str) -> bool:
)


def simple_version_parse(version: str):
def simple_version_parse(version: str) -> Tuple[int, ...]:
"""
Returns a 3 tuple of the versions major, minor, and patch.
Raises a value error if the version string is unsupported.
Expand Down Expand Up @@ -75,7 +80,7 @@ def _pytest_expected_on_right() -> bool:
return pytest_version >= (5, 4, 0)


def flatten_dict(d: dict):
def flatten_dict(d: _RecursiveDict[_K, _V]) -> List[Tuple[List[_K], _V]]:
"""
Returns the flattened dict representation of the given dict.

Expand All @@ -91,22 +96,24 @@ def flatten_dict(d: dict):
[(['a'], 1), (['b', 'c'], 2)]
"""
assert type(d) is dict
result = []
_flatten_dict(d, result, [])
result: List[Tuple[List[_K], _V]] = []
_flatten_dict(d, result, []) # type: ignore[misc]
return result


def _flatten_dict(obj, result, prefix):
if type(obj) is dict:
for k, v in obj.items():
prefix.append(k)
_flatten_dict(v, result, prefix)
prefix.pop()
else:
result.append((list(prefix), obj))
def _flatten_dict(
obj: _RecursiveDict[_K, _V], result: List[Tuple[List[_K], _V]], prefix: List[_K]
) -> None:
for k, v in obj.items():
prefix.append(k)
if type(v) is dict:
_flatten_dict(cast(_RecursiveDict[_K, _V], v), result, prefix)
else:
result.append((list(prefix), cast(_V, v)))
prefix.pop()


def flatten_filesystem_dict(d):
def flatten_filesystem_dict(d: _RecursiveDict[str, _V]) -> Dict[str, _V]:
"""
Returns the flattened dict of a nested dictionary structure describing a filesystem.

Expand Down
70 changes: 46 additions & 24 deletions pytest_snapshot/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,27 @@
import os
import re
from pathlib import Path
from typing import Union
from typing import Any, AnyStr, Callable, Iterator, List, Optional, Tuple, Union

import pytest
import _pytest.python

from pytest_snapshot._utils import shorten_path, get_valid_filename, _pytest_expected_on_right, flatten_filesystem_dict
try:
from pytest import Parser as _Parser
except ImportError:
from _pytest.config.argparsing import Parser as _Parser

try:
from pytest import FixtureRequest as _FixtureRequest
except ImportError:
from _pytest.fixtures import FixtureRequest as _FixtureRequest

from pytest_snapshot._utils import shorten_path, get_valid_filename, _pytest_expected_on_right
from pytest_snapshot._utils import flatten_filesystem_dict, _RecursiveDict

PARAMETRIZED_TEST_REGEX = re.compile(r'^.*?\[(.*)]$')


def pytest_addoption(parser):
def pytest_addoption(parser: _Parser) -> None:
group = parser.getgroup('snapshot')
group.addoption(
'--snapshot-update',
Expand All @@ -27,7 +37,9 @@ def pytest_addoption(parser):


@pytest.fixture
def snapshot(request):
def snapshot(request: _FixtureRequest) -> Iterator["Snapshot"]:
# FIXME Properly handle different node type
assert isinstance(request.node, pytest.Function)
default_snapshot_dir = _get_default_snapshot_dir(request.node)

with Snapshot(request.config.option.snapshot_update,
Expand All @@ -36,7 +48,7 @@ def snapshot(request):
yield snapshot


def _assert_equal(value, snapshot) -> None:
def _assert_equal(value: AnyStr, snapshot: AnyStr) -> None:
if _pytest_expected_on_right():
assert value == snapshot
else:
Expand Down Expand Up @@ -68,12 +80,12 @@ def _file_decode(data: bytes) -> str:


class Snapshot:
_snapshot_update = None # type: bool
_allow_snapshot_deletion = None # type: bool
_created_snapshots = None # type: List[Path]
_updated_snapshots = None # type: List[Path]
_snapshots_to_delete = None # type: List[Path]
_snapshot_dir = None # type: Path
_snapshot_update: bool
_allow_snapshot_deletion: bool
_created_snapshots: List[Path]
_updated_snapshots: List[Path]
_snapshots_to_delete: List[Path]
_snapshot_dir: Path

def __init__(self, snapshot_update: bool, allow_snapshot_deletion: bool, snapshot_dir: Path):
self._snapshot_update = snapshot_update
Expand All @@ -83,10 +95,10 @@ def __init__(self, snapshot_update: bool, allow_snapshot_deletion: bool, snapsho
self._updated_snapshots = []
self._snapshots_to_delete = []

def __enter__(self):
def __enter__(self) -> "Snapshot":
return self

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, *_: Any) -> None:
if self._created_snapshots or self._updated_snapshots or self._snapshots_to_delete:
message_lines = ['Snapshot directory was modified: {}'.format(shorten_path(self.snapshot_dir)),
' (verify that the changes are expected before committing them to version control)']
Expand All @@ -112,14 +124,14 @@ def __exit__(self, exc_type, exc_val, exc_tb):
pytest.fail('\n'.join(message_lines), pytrace=False)

@property
def snapshot_dir(self):
def snapshot_dir(self) -> Path:
return self._snapshot_dir

@snapshot_dir.setter
def snapshot_dir(self, value):
def snapshot_dir(self, value: Union[str, 'os.PathLike[str]']) -> None:
self._snapshot_dir = Path(value).absolute()

def _snapshot_path(self, snapshot_name: Union[str, Path]) -> Path:
def _snapshot_path(self, snapshot_name: Union[str, 'os.PathLike[str]']) -> Path:
"""
Returns the absolute path to the given snapshot.
"""
Expand All @@ -135,7 +147,11 @@ def _snapshot_path(self, snapshot_name: Union[str, Path]) -> Path:

return snapshot_path

def _get_compare_encode_decode(self, value: Union[str, bytes]):
def _get_compare_encode_decode(self, value: AnyStr) -> Tuple[
Callable[[AnyStr, AnyStr], None],
Callable[[AnyStr], bytes],
Callable[[bytes], AnyStr]
]:
"""
Returns a 3-tuple of a compare function, an encoding function, and a decoding function.

Expand All @@ -147,11 +163,12 @@ def _get_compare_encode_decode(self, value: Union[str, bytes]):
if isinstance(value, str):
return _assert_equal, _file_encode, _file_decode
elif isinstance(value, bytes):
return _assert_equal, lambda x: x, lambda x: x
noop: Callable[[bytes], bytes] = lambda x: x
return _assert_equal, noop, noop
else:
raise TypeError('value must be str or bytes')

def assert_match(self, value: Union[str, bytes], snapshot_name: Union[str, Path]):
def assert_match(self, value: AnyStr, snapshot_name: Union[str, 'os.PathLike[str]']) -> None:
"""
Asserts that ``value`` equals the current value of the snapshot with the given ``snapshot_name``.

Expand Down Expand Up @@ -185,6 +202,7 @@ def assert_match(self, value: Union[str, bytes], snapshot_name: Union[str, Path]
else:
if encoded_expected_value is not None:
expected_value = decode(encoded_expected_value)
snapshot_diff_msg: Optional[str]
try:
compare(value, expected_value)
except AssertionError as e:
Expand All @@ -202,7 +220,11 @@ def assert_match(self, value: Union[str, bytes], snapshot_name: Union[str, Path]
"snapshot {} doesn't exist. (run pytest with --snapshot-update to create it)".format(
shorten_path(snapshot_path)))

def assert_match_dir(self, dir_dict: dict, snapshot_dir_name: Union[str, Path]):
def assert_match_dir(
self,
dir_dict: _RecursiveDict[str, Union[bytes, str]],
snapshot_dir_name: Union[str, 'os.PathLike[str]']
) -> None:
"""
Asserts that the values in dir_dict equal the current values in the given snapshot directory.

Expand All @@ -214,7 +236,7 @@ def assert_match_dir(self, dir_dict: dict, snapshot_dir_name: Union[str, Path]):
raise TypeError('dir_dict must be a dictionary')

snapshot_dir_path = self._snapshot_path(snapshot_dir_name)
values_by_filename = flatten_filesystem_dict(dir_dict)
values_by_filename = flatten_filesystem_dict(dir_dict) # type: ignore[misc]
if snapshot_dir_path.is_dir():
existing_names = {p.relative_to(snapshot_dir_path).as_posix()
for p in snapshot_dir_path.rglob('*') if p.is_file()}
Expand Down Expand Up @@ -242,10 +264,10 @@ def assert_match_dir(self, dir_dict: dict, snapshot_dir_name: Union[str, Path]):

# Call assert_match to add, update, or assert equality for all snapshot files in the directory.
for name, value in values_by_filename.items():
self.assert_match(value, snapshot_dir_path.joinpath(name))
self.assert_match(value, snapshot_dir_path.joinpath(name)) # pyright: ignore


def _get_default_snapshot_dir(node: _pytest.python.Function) -> Path:
def _get_default_snapshot_dir(node: pytest.Function) -> Path:
"""
Returns the default snapshot directory for the pytest test.
"""
Expand Down
Empty file added pytest_snapshot/py.typed
Empty file.
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ classifiers =

[options]
packages = pytest_snapshot
zip_safe = False
include_package_data = True
python_requires = >=3.5
install_requires =
pytest >= 3.0.0
Expand Down
17 changes: 14 additions & 3 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
envlist =
# Pytest <6.2.5 not supported on Python >=3.10
py{36,37,38,39}-pytest{3,4,5}-coverage
py{35,36,37,38,39,310,311,312,3}-pytest{6,}-coverage
py{36,37,38,39,310,311,312,3}-pytest{6,}-coverage
# Coverage is slow in pypy
pypy3-pytest{6,}
flake8
pyright
mypy

[testenv]
deps =
Expand All @@ -32,18 +34,27 @@ skip_install = true
deps = flake8
commands = flake8 pytest_snapshot setup.py tests

[testenv:pyright]
deps = pyright
commands = pyright --verifytypes pytest_snapshot --ignoreexternal

[testenv:mypy]
deps =
mypy
py
commands = mypy -p pytest_snapshot

[flake8]
max-line-length = 120

[gh-actions]
python =
3.5: py35
3.6: py36
3.7: py37
3.8: py38
3.9: py39
3.10: py310
3.11: py311
3.12: py312, flake8
3.12: py312, flake8, mypy, pyright
3: py3
pypy-3.10: pypy3