Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support static type checking #59

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
include LICENSE
include README.rst
include pytest_snapshot/py.typed

recursive-exclude * __pycache__
recursive-exclude * *.py[co]
9 changes: 5 additions & 4 deletions pytest_snapshot/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import re
from pathlib import Path
from typing import List, Tuple

import pytest

Expand Down Expand Up @@ -44,7 +45,7 @@ def might_be_valid_filename(s: str) -> bool:
)


def simple_version_parse(version: str):
def simple_version_parse(version: str) -> Tuple[int, int, int]:
lonelyteapot marked this conversation as resolved.
Show resolved Hide resolved
"""
Returns a 3 tuple of the versions major, minor, and patch.
Raises a value error if the version string is unsupported.
Expand Down Expand Up @@ -75,7 +76,7 @@ def _pytest_expected_on_right() -> bool:
return pytest_version >= (5, 4, 0)


def flatten_dict(d: dict):
def flatten_dict(d: dict) -> List[Tuple[List, ...]]:
lonelyteapot marked this conversation as resolved.
Show resolved Hide resolved
"""
Returns the flattened dict representation of the given dict.

Expand All @@ -96,7 +97,7 @@ def flatten_dict(d: dict):
return result


def _flatten_dict(obj, result, prefix):
def _flatten_dict(obj: dict, result: list, prefix: list) -> None:
lonelyteapot marked this conversation as resolved.
Show resolved Hide resolved
if type(obj) is dict:
for k, v in obj.items():
prefix.append(k)
Expand All @@ -106,7 +107,7 @@ def _flatten_dict(obj, result, prefix):
result.append((list(prefix), obj))


def flatten_filesystem_dict(d):
def flatten_filesystem_dict(d: dict) -> dict:
lonelyteapot marked this conversation as resolved.
Show resolved Hide resolved
"""
Returns the flattened dict of a nested dictionary structure describing a filesystem.

Expand Down
24 changes: 13 additions & 11 deletions pytest_snapshot/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import re
from pathlib import Path
from typing import List, Union
from typing import Any, Callable, List, Tuple, Union

import pytest
import _pytest.python
Expand Down Expand Up @@ -36,7 +36,7 @@ def snapshot(request):
yield snapshot


def _assert_equal(value, snapshot) -> None:
def _assert_equal(value: Any, snapshot: Any) -> None:
lonelyteapot marked this conversation as resolved.
Show resolved Hide resolved
if _pytest_expected_on_right():
assert value == snapshot
else:
Expand Down Expand Up @@ -68,12 +68,12 @@ def _file_decode(data: bytes) -> str:


class Snapshot:
_snapshot_update = None # type: bool
_allow_snapshot_deletion = None # type: bool
_created_snapshots = None # type: List[Path]
_updated_snapshots = None # type: List[Path]
_snapshots_to_delete = None # type: List[Path]
_snapshot_dir = None # type: Path
_snapshot_update: bool
_allow_snapshot_deletion: bool
_created_snapshots: List[Path]
_updated_snapshots: List[Path]
_snapshots_to_delete: List[Path]
_snapshot_dir: Path

def __init__(self, snapshot_update: bool, allow_snapshot_deletion: bool, snapshot_dir: Path):
self._snapshot_update = snapshot_update
Expand Down Expand Up @@ -135,7 +135,9 @@ def _snapshot_path(self, snapshot_name: Union[str, Path]) -> Path:

return snapshot_path

def _get_compare_encode_decode(self, value: Union[str, bytes]):
def _get_compare_encode_decode(self, value: Union[str, bytes]) -> Tuple[
Callable[[Any, Any], None], Callable[..., bytes], Callable[..., str]
]:
lonelyteapot marked this conversation as resolved.
Show resolved Hide resolved
"""
Returns a 3-tuple of a compare function, an encoding function, and a decoding function.

Expand All @@ -151,7 +153,7 @@ def _get_compare_encode_decode(self, value: Union[str, bytes]):
else:
raise TypeError('value must be str or bytes')

def assert_match(self, value: Union[str, bytes], snapshot_name: Union[str, Path]):
def assert_match(self, value: Union[str, bytes], snapshot_name: Union[str, Path]) -> None:
lonelyteapot marked this conversation as resolved.
Show resolved Hide resolved
"""
Asserts that ``value`` equals the current value of the snapshot with the given ``snapshot_name``.

Expand Down Expand Up @@ -202,7 +204,7 @@ def assert_match(self, value: Union[str, bytes], snapshot_name: Union[str, Path]
"snapshot {} doesn't exist. (run pytest with --snapshot-update to create it)".format(
shorten_path(snapshot_path)))

def assert_match_dir(self, dir_dict: dict, snapshot_dir_name: Union[str, Path]):
def assert_match_dir(self, dir_dict: dict, snapshot_dir_name: Union[str, Path]) -> None:
lonelyteapot marked this conversation as resolved.
Show resolved Hide resolved
"""
Asserts that the values in dir_dict equal the current values in the given snapshot directory.

Expand Down
Empty file added pytest_snapshot/py.typed
Empty file.
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ classifiers =

[options]
packages = pytest_snapshot
zip_safe = False
include_package_data = True
python_requires = >=3.5
install_requires =
pytest >= 3.0.0
Expand Down