From dfc773e997b7323209745eba2b0f235175b4719d Mon Sep 17 00:00:00 2001 From: Tobias Deiminger Date: Thu, 9 Dec 2021 08:22:58 +0100 Subject: [PATCH] Experimental: Use param ids for reordering items Determining if parameters are "the same" is difficult for pytest: - some user provided parameters may not be comparable/hashable - some are, but their equality/hash is determined by identity rather than by value, which is likely not a good grouping criterion - implicit fallbacks are likely to surprise the the user To solve this, we now rely on the parameter id as it was given via pytest.param(id=...,) or autogenerated by pytest. This way a user can explicitely tell pytest which parameters are "the same". This commit is somewhat larger, becuase we must take care that SubRequests and reordering agree in their meaning of "the same". --- src/_pytest/fixtures.py | 40 +++++--------- src/_pytest/python.py | 106 +++++++++++++++++++++++++------------ testing/python/fixtures.py | 20 +++---- 3 files changed, 96 insertions(+), 70 deletions(-) diff --git a/src/_pytest/fixtures.py b/src/_pytest/fixtures.py index e2abbaee287..78f8e3f9e93 100644 --- a/src/_pytest/fixtures.py +++ b/src/_pytest/fixtures.py @@ -5,7 +5,6 @@ import warnings from collections import defaultdict from collections import deque -from collections.abc import Hashable from contextlib import suppress from pathlib import Path from types import TracebackType @@ -241,24 +240,6 @@ def getfixturemarker(obj: object) -> Optional["FixtureFunctionMarker"]: _Key = Tuple[object, ...] -@attr.s(auto_attribs=True, eq=False) -class SafeHashWrapper: - obj: Any - - def __eq__(self, other) -> Any: - try: - res = self.obj == other - bool(res) - return res - except Exception: - return id(self.obj) == id(other) - - def __hash__(self) -> Any: - if isinstance(self.obj, Hashable): - return hash(self.obj) - return hash(id(self.obj)) - - def get_parametrized_fixture_keys(item: nodes.Item, scope: Scope) -> Iterator[_Key]: """Return list of keys for all parametrized arguments which match the specified scope.""" @@ -272,19 +253,18 @@ def get_parametrized_fixture_keys(item: nodes.Item, scope: Scope) -> Iterator[_K # cs.indices.items() is random order of argnames. Need to # sort this so that different calls to # get_parametrized_fixture_keys will be deterministic. - for argname, param_index in sorted(cs.indices.items()): + for argname, param_id in sorted(cs.ids.items()): if cs._arg2scope[argname] != scope: continue - param = SafeHashWrapper(cs.params.get(argname, param_index)) if scope is Scope.Session: - key: _Key = (argname, param) + key: _Key = (argname, param_id) elif scope is Scope.Package: - key = (argname, param, item.path.parent) + key = (argname, param_id, item.path.parent) elif scope is Scope.Module: - key = (argname, param, item.path) + key = (argname, param_id, item.path) elif scope is Scope.Class: item_cls = item.cls # type: ignore[attr-defined] - key = (argname, param, item.path, item_cls) + key = (argname, param_id, item.path, item_cls) else: assert_never(scope) yield key @@ -658,6 +638,7 @@ def _compute_fixture_value(self, fixturedef: "FixtureDef[object]") -> None: except (AttributeError, ValueError): param = NOTSET param_index = 0 + param_id = "" has_params = fixturedef.params is not None fixtures_not_supported = getattr(funcitem, "nofuncargs", False) if has_params and fixtures_not_supported: @@ -697,13 +678,14 @@ def _compute_fixture_value(self, fixturedef: "FixtureDef[object]") -> None: fail(msg, pytrace=False) else: param_index = funcitem.callspec.indices[argname] + param_id = funcitem.callspec.ids[argname] # If a parametrize invocation set a scope it will override # the static scope defined with the fixture function. with suppress(KeyError): scope = funcitem.callspec._arg2scope[argname] subrequest = SubRequest( - self, scope, param, param_index, fixturedef, _ispytest=True + self, scope, param, param_index, param_id, fixturedef, _ispytest=True ) # Check if a higher-level scoped fixture accesses a lower level one. @@ -788,6 +770,7 @@ def __init__( scope: Scope, param: Any, param_index: int, + param_id: str, fixturedef: "FixtureDef[object]", *, _ispytest: bool = False, @@ -798,6 +781,7 @@ def __init__( if param is not NOTSET: self.param = param self.param_index = param_index + self.param_id = param_id self._scope = scope self._fixturedef = fixturedef self._pyfuncitem = request._pyfuncitem @@ -1072,7 +1056,7 @@ def execute(self, request: SubRequest) -> FixtureValue: # note: comparison with `==` can fail (or be expensive) for e.g. # numpy arrays (#6497). cache_key = self.cached_result[1] - if my_cache_key is cache_key: + if my_cache_key == cache_key: if self.cached_result[2] is not None: _, val, tb = self.cached_result[2] raise val.with_traceback(tb) @@ -1089,7 +1073,7 @@ def execute(self, request: SubRequest) -> FixtureValue: return result def cache_key(self, request: SubRequest) -> object: - return request.param_index if not hasattr(request, "param") else request.param + return request.param_id def __repr__(self) -> str: return "".format( diff --git a/src/_pytest/python.py b/src/_pytest/python.py index d9fccde9a95..f81c8958ac0 100644 --- a/src/_pytest/python.py +++ b/src/_pytest/python.py @@ -9,6 +9,7 @@ import warnings from collections import Counter from collections import defaultdict +from collections import OrderedDict from functools import partial from pathlib import Path from typing import Any @@ -904,6 +905,55 @@ def hasnew(obj: object) -> bool: return False +class ParamArgIdSet: + @staticmethod + def from_idvalset( + idx: int, + parameterset: ParameterSet, + argnames: Iterable[str], + idfn: Optional[Callable[[Any], Optional[object]]], + ids: Optional[List[Union[None, str]]], + nodeid: Optional[str], + config: Optional[Config], + ) -> "ParamArgIdSet": + if parameterset.id is not None: + return ParamArgIdSet( + parameterset.id, {argname: parameterset.id for argname in argnames} + ) + id = None if ids is None or idx >= len(ids) else ids[idx] + if id is None: + this_id = [] + this_id_by_arg = {} + for val, argname in zip(parameterset.values, argnames): + arg_id = _idval(val, argname, idx, idfn, nodeid=nodeid, config=config) + this_id.append(arg_id) + this_id_by_arg[argname] = arg_id + return ParamArgIdSet("-".join(this_id), this_id_by_arg) + parameter_id = _ascii_escaped_by_config(id, config) + return ParamArgIdSet( + parameter_id, {argname: parameter_id for argname in argnames} + ) + + def __init__(self, id: str, ids_by_arg: Dict[str, str]): + self.id = id + self.ids_by_arg = ids_by_arg + self.suffix = "" + + def __eq__(self, other): + if isinstance(other, str): + return str(self) == other + return (self.id, self.suffix) == (other.id, other.suffix) + + def __hash__(self): + return hash((self.id, self.suffix)) + + def __str__(self) -> str: + return f"{self.id}{self.suffix}" + + def __repr__(self) -> str: + return str(self) + + @final @attr.s(frozen=True, slots=True, auto_attribs=True) class CallSpec2: @@ -922,6 +972,8 @@ class CallSpec2: params: Dict[str, object] = attr.Factory(dict) # arg name -> arg index. indices: Dict[str, int] = attr.Factory(dict) + # arg name -> arg id + ids: Dict[str, str] = attr.Factory(OrderedDict) # Used for sorting parametrized resources. _arg2scope: Dict[str, Scope] = attr.Factory(dict) # Parts which will be added to the item's name in `[..]` separated by "-". @@ -935,7 +987,7 @@ def setmulti( valtypes: Mapping[str, "Literal['params', 'funcargs']"], argnames: Iterable[str], valset: Iterable[object], - id: str, + id: ParamArgIdSet, marks: Iterable[Union[Mark, MarkDecorator]], scope: Scope, param_index: int, @@ -943,6 +995,7 @@ def setmulti( funcargs = self.funcargs.copy() params = self.params.copy() indices = self.indices.copy() + ids = self.ids.copy() arg2scope = self._arg2scope.copy() for arg, val in zip(argnames, valset): if arg in params or arg in funcargs: @@ -955,13 +1008,15 @@ def setmulti( else: assert_never(valtype_for_arg) indices[arg] = param_index + ids[arg] = id.ids_by_arg[arg] arg2scope[arg] = scope return CallSpec2( funcargs=funcargs, params=params, arg2scope=arg2scope, indices=indices, - idlist=[*self._idlist, id], + ids=ids, + idlist=[*self._idlist, str(id)], marks=[*self.marks, *normalize_mark_list(marks)], ) @@ -1128,20 +1183,26 @@ def parametrize( if generated_ids is not None: ids = generated_ids - ids = self._resolve_arg_ids( + resolved_ids = self._resolve_arg_ids( argnames, ids, parameters, nodeid=self.definition.nodeid ) # Store used (possibly generated) ids with parametrize Marks. if _param_mark and _param_mark._param_ids_from and generated_ids is None: - object.__setattr__(_param_mark._param_ids_from, "_param_ids_generated", ids) + object.__setattr__( + _param_mark._param_ids_from, + "_param_ids_generated", + [str(param_id) for param_id in resolved_ids], + ) # Create the new calls: if we are parametrize() multiple times (by applying the decorator # more than once) then we accumulate those calls generating the cartesian product # of all calls. newcalls = [] for callspec in self._calls or [CallSpec2()]: - for param_index, (param_id, param_set) in enumerate(zip(ids, parameters)): + for param_index, (param_id, param_set) in enumerate( + zip(resolved_ids, parameters) + ): newcallspec = callspec.setmulti( valtypes=arg_values_types, argnames=argnames, @@ -1165,7 +1226,7 @@ def _resolve_arg_ids( ], parameters: Sequence[ParameterSet], nodeid: str, - ) -> List[str]: + ) -> List[ParamArgIdSet]: """Resolve the actual ids for the given argnames, based on the ``ids`` parameter given to ``parametrize``. @@ -1385,28 +1446,6 @@ def _idval( return str(argname) + str(idx) -def _idvalset( - idx: int, - parameterset: ParameterSet, - argnames: Iterable[str], - idfn: Optional[Callable[[Any], Optional[object]]], - ids: Optional[List[Union[None, str]]], - nodeid: Optional[str], - config: Optional[Config], -) -> str: - if parameterset.id is not None: - return parameterset.id - id = None if ids is None or idx >= len(ids) else ids[idx] - if id is None: - this_id = [ - _idval(val, argname, idx, idfn, nodeid=nodeid, config=config) - for val, argname in zip(parameterset.values, argnames) - ] - return "-".join(this_id) - else: - return _ascii_escaped_by_config(id, config) - - def idmaker( argnames: Iterable[str], parametersets: Iterable[ParameterSet], @@ -1416,7 +1455,7 @@ def idmaker( nodeid: Optional[str] = None, ) -> List[str]: resolved_ids = [ - _idvalset( + ParamArgIdSet.from_idvalset( valindex, parameterset, argnames, idfn, ids, config=config, nodeid=nodeid ) for valindex, parameterset in enumerate(parametersets) @@ -1427,16 +1466,17 @@ def idmaker( if len(unique_ids) != len(resolved_ids): # Record the number of occurrences of each test ID. - test_id_counts = Counter(resolved_ids) + test_id_counts = Counter([str(resolved_id) for resolved_id in resolved_ids]) # Map the test ID to its next suffix. test_id_suffixes: Dict[str, int] = defaultdict(int) # Suffix non-unique IDs to make them unique. for index, test_id in enumerate(resolved_ids): - if test_id_counts[test_id] > 1: - resolved_ids[index] = f"{test_id}{test_id_suffixes[test_id]}" - test_id_suffixes[test_id] += 1 + current_test_id = str(test_id) + if test_id_counts[current_test_id] > 1: + resolved_ids[index].suffix = test_id_suffixes[current_test_id] + test_id_suffixes[current_test_id] += 1 return resolved_ids diff --git a/testing/python/fixtures.py b/testing/python/fixtures.py index 9fb095b2455..43eb172a579 100644 --- a/testing/python/fixtures.py +++ b/testing/python/fixtures.py @@ -1320,25 +1320,27 @@ def test_optimize_by_reorder_indirect(self, pytester: Pytester) -> None: @pytest.fixture(scope="session") def fix(request): - print(f'prepare foo-%s' % request.param) - yield request.param - print(f'teardown foo-%s' % request.param) + value = request.param["data"] if isinstance(request.param, dict) else request.param + print(f'prepare foo-%s' % value) + yield value + print(f'teardown foo-%s' % value) - @pytest.mark.parametrize("fix", ["data1", "data2"], indirect=True) + @pytest.mark.parametrize("fix", [1, pytest.param({"data": 2}, id="2")], indirect=True) def test1(fix): pass - @pytest.mark.parametrize("fix", ["data2", "data1"], indirect=True) + @pytest.mark.parametrize("fix", [pytest.param({"data": 2}, id="2"), 1], indirect=True) def test2(fix): pass """ ) + # pytest.param({"data": 2}, id="userid2") result = pytester.runpytest("-s") output = result.stdout.str() - assert output.count("prepare foo-data1") == 1 - assert output.count("prepare foo-data2") == 1 - assert output.count("teardown foo-data1") == 1 - assert output.count("teardown foo-data2") == 1 + assert output.count("prepare foo-1") == 1 + assert output.count("prepare foo-2") == 1 + assert output.count("teardown foo-1") == 1 + assert output.count("teardown foo-2") == 1 def test_funcarg_parametrized_and_used_twice(self, pytester: Pytester) -> None: pytester.makepyfile(