Skip to content

Commit

Permalink
Experimental: Use param ids for reordering items
Browse files Browse the repository at this point in the history
Determining if parameters are "the same" is difficult for pytest:
- some user provided parameters may not be comparable/hashable
- some are, but their equality/hash is determined by identity rather
  than by value, which is likely not a good grouping criterion
- implicit fallbacks are likely to surprise the the user

To solve this, we now rely on the parameter id as it was given
via pytest.param(id=...,) or autogenerated by pytest. This way a user
can explicitely tell pytest which parameters are "the same".

This commit is somewhat larger, becuase we must take care that SubRequests
and reordering agree in their meaning of "the same".
  • Loading branch information
Tobias Deiminger committed Dec 9, 2021
1 parent 2fdbc3e commit dfc773e
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 70 deletions.
40 changes: 12 additions & 28 deletions src/_pytest/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import warnings
from collections import defaultdict
from collections import deque
from collections.abc import Hashable
from contextlib import suppress
from pathlib import Path
from types import TracebackType
Expand Down Expand Up @@ -241,24 +240,6 @@ def getfixturemarker(obj: object) -> Optional["FixtureFunctionMarker"]:
_Key = Tuple[object, ...]


@attr.s(auto_attribs=True, eq=False)
class SafeHashWrapper:
obj: Any

def __eq__(self, other) -> Any:
try:
res = self.obj == other
bool(res)
return res
except Exception:
return id(self.obj) == id(other)

def __hash__(self) -> Any:
if isinstance(self.obj, Hashable):
return hash(self.obj)
return hash(id(self.obj))


def get_parametrized_fixture_keys(item: nodes.Item, scope: Scope) -> Iterator[_Key]:
"""Return list of keys for all parametrized arguments which match
the specified scope."""
Expand All @@ -272,19 +253,18 @@ def get_parametrized_fixture_keys(item: nodes.Item, scope: Scope) -> Iterator[_K
# cs.indices.items() is random order of argnames. Need to
# sort this so that different calls to
# get_parametrized_fixture_keys will be deterministic.
for argname, param_index in sorted(cs.indices.items()):
for argname, param_id in sorted(cs.ids.items()):
if cs._arg2scope[argname] != scope:
continue
param = SafeHashWrapper(cs.params.get(argname, param_index))
if scope is Scope.Session:
key: _Key = (argname, param)
key: _Key = (argname, param_id)

Check warning on line 260 in src/_pytest/fixtures.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/fixtures.py#L260

Added line #L260 was not covered by tests
elif scope is Scope.Package:
key = (argname, param, item.path.parent)
key = (argname, param_id, item.path.parent)
elif scope is Scope.Module:
key = (argname, param, item.path)
key = (argname, param_id, item.path)
elif scope is Scope.Class:
item_cls = item.cls # type: ignore[attr-defined]
key = (argname, param, item.path, item_cls)
key = (argname, param_id, item.path, item_cls)
else:
assert_never(scope)
yield key
Expand Down Expand Up @@ -658,6 +638,7 @@ def _compute_fixture_value(self, fixturedef: "FixtureDef[object]") -> None:
except (AttributeError, ValueError):
param = NOTSET
param_index = 0
param_id = ""
has_params = fixturedef.params is not None
fixtures_not_supported = getattr(funcitem, "nofuncargs", False)
if has_params and fixtures_not_supported:
Expand Down Expand Up @@ -697,13 +678,14 @@ def _compute_fixture_value(self, fixturedef: "FixtureDef[object]") -> None:
fail(msg, pytrace=False)
else:
param_index = funcitem.callspec.indices[argname]
param_id = funcitem.callspec.ids[argname]
# If a parametrize invocation set a scope it will override
# the static scope defined with the fixture function.
with suppress(KeyError):
scope = funcitem.callspec._arg2scope[argname]

subrequest = SubRequest(
self, scope, param, param_index, fixturedef, _ispytest=True
self, scope, param, param_index, param_id, fixturedef, _ispytest=True
)

# Check if a higher-level scoped fixture accesses a lower level one.
Expand Down Expand Up @@ -788,6 +770,7 @@ def __init__(
scope: Scope,
param: Any,
param_index: int,
param_id: str,
fixturedef: "FixtureDef[object]",
*,
_ispytest: bool = False,
Expand All @@ -798,6 +781,7 @@ def __init__(
if param is not NOTSET:
self.param = param
self.param_index = param_index
self.param_id = param_id
self._scope = scope
self._fixturedef = fixturedef
self._pyfuncitem = request._pyfuncitem
Expand Down Expand Up @@ -1072,7 +1056,7 @@ def execute(self, request: SubRequest) -> FixtureValue:
# note: comparison with `==` can fail (or be expensive) for e.g.
# numpy arrays (#6497).
cache_key = self.cached_result[1]
if my_cache_key is cache_key:
if my_cache_key == cache_key:
if self.cached_result[2] is not None:
_, val, tb = self.cached_result[2]
raise val.with_traceback(tb)
Expand All @@ -1089,7 +1073,7 @@ def execute(self, request: SubRequest) -> FixtureValue:
return result

def cache_key(self, request: SubRequest) -> object:
return request.param_index if not hasattr(request, "param") else request.param
return request.param_id

def __repr__(self) -> str:
return "<FixtureDef argname={!r} scope={!r} baseid={!r}>".format(
Expand Down
106 changes: 73 additions & 33 deletions src/_pytest/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import warnings
from collections import Counter
from collections import defaultdict
from collections import OrderedDict
from functools import partial
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -904,6 +905,55 @@ def hasnew(obj: object) -> bool:
return False


class ParamArgIdSet:
@staticmethod
def from_idvalset(
idx: int,
parameterset: ParameterSet,
argnames: Iterable[str],
idfn: Optional[Callable[[Any], Optional[object]]],
ids: Optional[List[Union[None, str]]],
nodeid: Optional[str],
config: Optional[Config],
) -> "ParamArgIdSet":
if parameterset.id is not None:
return ParamArgIdSet(
parameterset.id, {argname: parameterset.id for argname in argnames}
)
id = None if ids is None or idx >= len(ids) else ids[idx]
if id is None:
this_id = []
this_id_by_arg = {}
for val, argname in zip(parameterset.values, argnames):
arg_id = _idval(val, argname, idx, idfn, nodeid=nodeid, config=config)
this_id.append(arg_id)
this_id_by_arg[argname] = arg_id
return ParamArgIdSet("-".join(this_id), this_id_by_arg)
parameter_id = _ascii_escaped_by_config(id, config)
return ParamArgIdSet(
parameter_id, {argname: parameter_id for argname in argnames}
)

def __init__(self, id: str, ids_by_arg: Dict[str, str]):
self.id = id
self.ids_by_arg = ids_by_arg
self.suffix = ""

def __eq__(self, other):
if isinstance(other, str):
return str(self) == other
return (self.id, self.suffix) == (other.id, other.suffix)

def __hash__(self):
return hash((self.id, self.suffix))

def __str__(self) -> str:
return f"{self.id}{self.suffix}"

def __repr__(self) -> str:
return str(self)


Check warning on line 956 in src/_pytest/python.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/python.py#L956

Added line #L956 was not covered by tests
@final
@attr.s(frozen=True, slots=True, auto_attribs=True)
class CallSpec2:
Expand All @@ -922,6 +972,8 @@ class CallSpec2:
params: Dict[str, object] = attr.Factory(dict)
# arg name -> arg index.
indices: Dict[str, int] = attr.Factory(dict)
# arg name -> arg id
ids: Dict[str, str] = attr.Factory(OrderedDict)
# Used for sorting parametrized resources.
_arg2scope: Dict[str, Scope] = attr.Factory(dict)
# Parts which will be added to the item's name in `[..]` separated by "-".
Expand All @@ -935,14 +987,15 @@ def setmulti(
valtypes: Mapping[str, "Literal['params', 'funcargs']"],
argnames: Iterable[str],
valset: Iterable[object],
id: str,
id: ParamArgIdSet,
marks: Iterable[Union[Mark, MarkDecorator]],
scope: Scope,
param_index: int,
) -> "CallSpec2":
funcargs = self.funcargs.copy()
params = self.params.copy()
indices = self.indices.copy()
ids = self.ids.copy()
arg2scope = self._arg2scope.copy()
for arg, val in zip(argnames, valset):
if arg in params or arg in funcargs:
Expand All @@ -955,13 +1008,15 @@ def setmulti(
else:
assert_never(valtype_for_arg)
indices[arg] = param_index
ids[arg] = id.ids_by_arg[arg]
arg2scope[arg] = scope
return CallSpec2(
funcargs=funcargs,
params=params,
arg2scope=arg2scope,
indices=indices,
idlist=[*self._idlist, id],
ids=ids,
idlist=[*self._idlist, str(id)],
marks=[*self.marks, *normalize_mark_list(marks)],
)

Expand Down Expand Up @@ -1128,20 +1183,26 @@ def parametrize(
if generated_ids is not None:
ids = generated_ids

ids = self._resolve_arg_ids(
resolved_ids = self._resolve_arg_ids(
argnames, ids, parameters, nodeid=self.definition.nodeid
)

# Store used (possibly generated) ids with parametrize Marks.
if _param_mark and _param_mark._param_ids_from and generated_ids is None:
object.__setattr__(_param_mark._param_ids_from, "_param_ids_generated", ids)
object.__setattr__(
_param_mark._param_ids_from,
"_param_ids_generated",
[str(param_id) for param_id in resolved_ids],
)

# Create the new calls: if we are parametrize() multiple times (by applying the decorator
# more than once) then we accumulate those calls generating the cartesian product
# of all calls.
newcalls = []
for callspec in self._calls or [CallSpec2()]:
for param_index, (param_id, param_set) in enumerate(zip(ids, parameters)):
for param_index, (param_id, param_set) in enumerate(
zip(resolved_ids, parameters)
):
newcallspec = callspec.setmulti(
valtypes=arg_values_types,
argnames=argnames,
Expand All @@ -1165,7 +1226,7 @@ def _resolve_arg_ids(
],
parameters: Sequence[ParameterSet],
nodeid: str,
) -> List[str]:
) -> List[ParamArgIdSet]:
"""Resolve the actual ids for the given argnames, based on the ``ids`` parameter given
to ``parametrize``.
Expand Down Expand Up @@ -1385,28 +1446,6 @@ def _idval(
return str(argname) + str(idx)


def _idvalset(
idx: int,
parameterset: ParameterSet,
argnames: Iterable[str],
idfn: Optional[Callable[[Any], Optional[object]]],
ids: Optional[List[Union[None, str]]],
nodeid: Optional[str],
config: Optional[Config],
) -> str:
if parameterset.id is not None:
return parameterset.id
id = None if ids is None or idx >= len(ids) else ids[idx]
if id is None:
this_id = [
_idval(val, argname, idx, idfn, nodeid=nodeid, config=config)
for val, argname in zip(parameterset.values, argnames)
]
return "-".join(this_id)
else:
return _ascii_escaped_by_config(id, config)


def idmaker(
argnames: Iterable[str],
parametersets: Iterable[ParameterSet],
Expand All @@ -1416,7 +1455,7 @@ def idmaker(
nodeid: Optional[str] = None,
) -> List[str]:
resolved_ids = [
_idvalset(
ParamArgIdSet.from_idvalset(
valindex, parameterset, argnames, idfn, ids, config=config, nodeid=nodeid
)
for valindex, parameterset in enumerate(parametersets)
Expand All @@ -1427,16 +1466,17 @@ def idmaker(
if len(unique_ids) != len(resolved_ids):

# Record the number of occurrences of each test ID.
test_id_counts = Counter(resolved_ids)
test_id_counts = Counter([str(resolved_id) for resolved_id in resolved_ids])

# Map the test ID to its next suffix.
test_id_suffixes: Dict[str, int] = defaultdict(int)

# Suffix non-unique IDs to make them unique.
for index, test_id in enumerate(resolved_ids):
if test_id_counts[test_id] > 1:
resolved_ids[index] = f"{test_id}{test_id_suffixes[test_id]}"
test_id_suffixes[test_id] += 1
current_test_id = str(test_id)
if test_id_counts[current_test_id] > 1:
resolved_ids[index].suffix = test_id_suffixes[current_test_id]
test_id_suffixes[current_test_id] += 1

return resolved_ids

Expand Down
20 changes: 11 additions & 9 deletions testing/python/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -1320,25 +1320,27 @@ def test_optimize_by_reorder_indirect(self, pytester: Pytester) -> None:
@pytest.fixture(scope="session")
def fix(request):
print(f'prepare foo-%s' % request.param)
yield request.param
print(f'teardown foo-%s' % request.param)
value = request.param["data"] if isinstance(request.param, dict) else request.param
print(f'prepare foo-%s' % value)
yield value
print(f'teardown foo-%s' % value)
@pytest.mark.parametrize("fix", ["data1", "data2"], indirect=True)
@pytest.mark.parametrize("fix", [1, pytest.param({"data": 2}, id="2")], indirect=True)
def test1(fix):
pass
@pytest.mark.parametrize("fix", ["data2", "data1"], indirect=True)
@pytest.mark.parametrize("fix", [pytest.param({"data": 2}, id="2"), 1], indirect=True)
def test2(fix):
pass
"""
)
# pytest.param({"data": 2}, id="userid2")
result = pytester.runpytest("-s")
output = result.stdout.str()
assert output.count("prepare foo-data1") == 1
assert output.count("prepare foo-data2") == 1
assert output.count("teardown foo-data1") == 1
assert output.count("teardown foo-data2") == 1
assert output.count("prepare foo-1") == 1
assert output.count("prepare foo-2") == 1
assert output.count("teardown foo-1") == 1
assert output.count("teardown foo-2") == 1

def test_funcarg_parametrized_and_used_twice(self, pytester: Pytester) -> None:
pytester.makepyfile(
Expand Down

0 comments on commit dfc773e

Please sign in to comment.