Skip to content

Commit

Permalink
refactor: Strict typings
Browse files Browse the repository at this point in the history
  • Loading branch information
tony committed Sep 8, 2022
1 parent b831f10 commit ffe5aaa
Show file tree
Hide file tree
Showing 6 changed files with 215 additions and 100 deletions.
173 changes: 129 additions & 44 deletions src/pytest_sphinx.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,35 @@
* TODO
** CLEANUP: use the sphinx directive parser from the sphinx project
"""

import doctest
import enum
import re
import sys
import textwrap
import traceback
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import Iterator
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import _pytest.doctest
import pytest
from _pytest.config import Config
from _pytest.doctest import DoctestItem
from _pytest.doctest import _is_mocked
from _pytest.doctest import _patch_unwrap_mock_aware
from _pytest.main import Session
from _pytest.pathlib import import_path
from _pytest.python import Package

if TYPE_CHECKING:
import pdb
from doctest import _Out


class SphinxDoctestDirectives(enum.Enum):
Expand All @@ -43,21 +57,31 @@ class SphinxDoctestDirectives(enum.Enum):
)


def pytest_collect_file(path, parent):
def pytest_collect_file(
file_path: Path, parent: Union[Session, Package]
) -> Optional[Union["SphinxDoctestModule", "SphinxDoctestTextfile"]]:
config = parent.config
if path.ext == ".py":
if file_path.suffix == ".py":
if config.option.doctestmodules:
return SphinxDoctestModule.from_parent(parent, path=Path(path.strpath))
elif _is_doctest(config, path, parent):
return SphinxDoctestTextfile.from_parent(parent, path=Path(path.strpath))
mod: Union[
"SphinxDoctestModule", "SphinxDoctestTextfile"
] = SphinxDoctestModule.from_parent(parent, path=file_path)
return mod
elif _is_doctest(config, file_path, parent):
return SphinxDoctestTextfile.from_parent(parent, path=file_path) # type: ignore
return None


GlobDict = Dict[str, Any]

def _is_doctest(config, path, parent):
if path.ext in (".txt", ".rst") and parent.session.isinitpath(path):

def _is_doctest(config: Config, path: Path, parent: Union[Session, Package]) -> bool:
if path.suffix in (".txt", ".rst") and parent.session.isinitpath(path):
return True
globs = config.getoption("doctestglob") or ["test*.txt"]
assert isinstance(globs, list)
for glob in globs:
if path.check(fnmatch=glob):
if path.match(path_pattern=glob):
return True
return False

Expand All @@ -80,7 +104,9 @@ def _is_doctest(config, path, parent):
)


def _split_into_body_and_options(section_content):
def _split_into_body_and_options(
section_content: str,
) -> Tuple[str, Optional[str], Dict[int, bool]]:
"""Parse the the full content of a directive and split it.
It is split into a string, where the options (:options:, :hide: and
Expand Down Expand Up @@ -119,12 +145,14 @@ def _split_into_body_and_options(section_content):
for line in lines:
stripped = line.strip()
if _OPTION_SKIPIF_RE.match(stripped):
skipif_expr = _OPTION_SKIPIF_RE.match(stripped).group(1)
skipif_match = _OPTION_SKIPIF_RE.match(stripped)
assert skipif_match is not None
skipif_expr = skipif_match.group(1)
i += 1
elif _OPTION_DIRECTIVE_RE.match(stripped):
option_strings = (
_OPTION_DIRECTIVE_RE.match(stripped).group(1).replace(",", " ").split()
)
directive_match = _OPTION_DIRECTIVE_RE.match(stripped)
assert directive_match is not None
option_strings = directive_match.group(1).replace(",", " ").split()
for option in option_strings:
if (
option[0] not in "+-"
Expand Down Expand Up @@ -153,7 +181,9 @@ def _split_into_body_and_options(section_content):
return body, skipif_expr, flag_settings


def _get_next_textoutputsections(sections, index):
def _get_next_textoutputsections(
sections: List["Section"], index: int
) -> Iterator["Section"]:
"""Yield successive TESTOUTPUT sections."""
for j in range(index, len(sections)):
section = sections[j]
Expand All @@ -163,8 +193,17 @@ def _get_next_textoutputsections(sections, index):
break


SectionGroups = Optional[List[str]]


class Section:
def __init__(self, directive, content, lineno, groups=None):
def __init__(
self,
directive: SphinxDoctestDirectives,
content: str,
lineno: int,
groups: SectionGroups = None,
) -> None:
super().__init__()
self.directive = directive
self.groups = groups
Expand All @@ -180,14 +219,16 @@ def __init__(self, directive, content, lineno, groups=None):
self.options = options


def get_sections(docstring):
def get_sections(docstring: str) -> List[Union[Any, Section]]:
lines = textwrap.dedent(docstring).splitlines()
sections = []

def _get_indentation(line):
def _get_indentation(line: str) -> int:
return len(line) - len(line.lstrip())

def add_match(directive, i, j, groups):
def add_match(
directive: SphinxDoctestDirectives, i: int, j: int, groups: SectionGroups
) -> None:
sections.append(
Section(
directive,
Expand Down Expand Up @@ -227,28 +268,32 @@ def add_match(directive, i, j, groups):
return sections


def docstring2examples(docstring, globs=None):
def docstring2examples(
docstring: str, globs: Optional[GlobDict] = None
) -> List[Union[Any, doctest.Example]]:
"""
Parse all sphinx test directives in the docstring and create a
list of examples.
"""
# TODO subclass doctest.DocTestParser instead?

if not globs:
if globs is None:
globs = {}

sections = get_sections(docstring)

def get_testoutput_section_data(section):
def get_testoutput_section_data(
section: "Section",
) -> Tuple[str, Dict[int, bool], int, Optional[Any]]:
want = section.body
exc_msg = None
options = {}
options: Dict[int, bool] = {}

if section.skipif_expr and eval(section.skipif_expr, globs):
want = ""
else:
options = section.options
match = doctest.DocTestParser._EXCEPTION_RE.match(want)
match = doctest.DocTestParser._EXCEPTION_RE.match(want) # type: ignore
if match:
exc_msg = match.group("msg")

Expand Down Expand Up @@ -302,7 +347,13 @@ class SphinxDocTestRunner(doctest.DebugRunner):
`compile` function instead of 'exec'.
"""

def _DocTestRunner__run(self, test, compileflags, out):
_checker: "doctest.OutputChecker"
debugger: "pdb.Pdb"

def _DocTestRunner__run(
self, test: doctest.DocTest, compileflags: int, out: "_Out"
) -> doctest.TestResults:

"""
Run the examples in `test`.
Expand Down Expand Up @@ -375,8 +426,9 @@ def _DocTestRunner__run(self, test, compileflags, out):
exception = sys.exc_info()
self.debugger.set_continue() # ==== Example Finished ====

got = self._fakeout.getvalue() # the actual output
self._fakeout.truncate(0)
# the actual output
got = self._fakeout.getvalue() # type: ignore
self._fakeout.truncate(0) # type: ignore
outcome = FAILURE # guilty until proved innocent or insane

# If the example executed without raising any exceptions,
Expand All @@ -389,7 +441,7 @@ def _DocTestRunner__run(self, test, compileflags, out):
else:
exc_msg = traceback.format_exception_only(*exception[:2])[-1]
if not quiet:
got += doctest._exception_traceback(exception)
got += doctest._exception_traceback(exception) # type:ignore

# If `example.exc_msg` is None, then we weren't expecting
# an exception.
Expand All @@ -403,8 +455,10 @@ def _DocTestRunner__run(self, test, compileflags, out):
# Another chance if they didn't care about the detail.
elif self.optionflags & doctest.IGNORE_EXCEPTION_DETAIL:
if check(
doctest._strip_exception_details(example.exc_msg),
doctest._strip_exception_details(exc_msg),
doctest._strip_exception_details( # type:ignore
example.exc_msg,
),
doctest._strip_exception_details(exc_msg), # type:ignore
self.optionflags,
):
outcome = SUCCESS
Expand All @@ -419,7 +473,14 @@ def _DocTestRunner__run(self, test, compileflags, out):
failures += 1
elif outcome is BOOM:
if not quiet:
self.report_unexpected_exception(out, test, example, exception)
assert exception is not None
assert out is not None
self.report_unexpected_exception(
out,
test,
example,
exception, # type:ignore
)
failures += 1
else:
assert False, ("unknown outcome", outcome)
Expand All @@ -431,12 +492,19 @@ def _DocTestRunner__run(self, test, compileflags, out):
self.optionflags = original_optionflags

# Record and return the number of failures and tries.
self._DocTestRunner__record_outcome(test, failures, tries)
self._DocTestRunner__record_outcome(test, failures, tries) # type:ignore
return doctest.TestResults(failures, tries)


class SphinxDocTestParser:
def get_doctest(self, docstring, globs, name, filename, lineno):
def get_doctest(
self,
docstring: str,
globs: Dict[str, Any],
name: str,
filename: str,
lineno: int,
) -> doctest.DocTest:
# TODO document why we need to overwrite? get_doctest
return doctest.DocTest(
examples=docstring2examples(docstring, globs=globs),
Expand All @@ -451,16 +519,16 @@ def get_doctest(self, docstring, globs, name, filename, lineno):
class SphinxDoctestTextfile(pytest.Module):
obj = None

def collect(self):
def collect(self) -> Iterator[_pytest.doctest.DoctestItem]:
# inspired by doctest.testfile; ideally we would use it directly,
# but it doesn't support passing a custom checker
encoding = self.config.getini("doctest_encoding")
text = self.fspath.read_text(encoding)
name = self.fspath.basename

optionflags = _pytest.doctest.get_optionflags(self)
optionflags = _pytest.doctest.get_optionflags(self) # type:ignore
runner = SphinxDocTestRunner(
verbose=0,
verbose=False,
optionflags=optionflags,
checker=_pytest.doctest._get_checker(),
)
Expand All @@ -476,12 +544,15 @@ def collect(self):

if test.examples:
yield DoctestItem.from_parent(
parent=self, name=test.name, runner=runner, dtest=test
parent=self, # type:ignore
name=test.name,
runner=runner,
dtest=test,
)


class SphinxDoctestModule(pytest.Module):
def collect(self):
def collect(self) -> Iterator[_pytest.doctest.DoctestItem]:
if self.fspath.basename == "conftest.py":
module = self.config.pluginmanager._importconftest(
self.path,
Expand All @@ -496,7 +567,7 @@ def collect(self):
pytest.skip("unable to import module %r" % self.path)
else:
raise
optionflags = _pytest.doctest.get_optionflags(self)
optionflags = _pytest.doctest.get_optionflags(self) # type:ignore

class MockAwareDocTestFinder(doctest.DocTestFinder):
"""
Expand All @@ -508,11 +579,20 @@ class MockAwareDocTestFinder(doctest.DocTestFinder):
fix taken from https://github.com/pytest-dev/pytest/pull/4212/
"""

def _find(self, tests, obj, name, module, source_lines, globs, seen):
def _find(
self,
tests: List[doctest.DocTest],
obj: str,
name: str,
module: Any,
source_lines: Optional[List[str]],
globs: GlobDict,
seen: Dict[int, int],
) -> None:
if _is_mocked(obj):
return
with _patch_unwrap_mock_aware():
doctest.DocTestFinder._find(
doctest.DocTestFinder._find( # type:ignore
self,
tests,
obj,
Expand All @@ -524,18 +604,23 @@ def _find(self, tests, obj, name, module, source_lines, globs, seen):
)

if sys.version_info < (3, 10):
finder = MockAwareDocTestFinder(parser=SphinxDocTestParser())
finder = MockAwareDocTestFinder(
parser=SphinxDocTestParser() # type:ignore
)
else:
finder = doctest.DocTestFinder(parser=SphinxDocTestParser())

runner = SphinxDocTestRunner(
verbose=0,
verbose=False,
optionflags=optionflags,
checker=_pytest.doctest._get_checker(),
)

for test in finder.find(module, module.__name__):
if test.examples:
yield DoctestItem.from_parent(
parent=self, name=test.name, runner=runner, dtest=test
parent=self, # type: ignore
name=test.name,
runner=runner,
dtest=test,
)
Loading

0 comments on commit ffe5aaa

Please sign in to comment.