From 67a8cc99198007d70a439148d2124814f08693dc Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Thu, 17 Aug 2023 00:18:02 +0100 Subject: [PATCH 1/6] Infer ParamSpec constraint from arguments --- mypy/checkexpr.py | 45 +++++++++--- mypy/constraints.py | 61 ++++++++++++++-- mypy/expandtype.py | 2 - mypy/infer.py | 3 +- .../unit/check-parameter-specification.test | 72 +++++++++++++++---- test-data/unit/fixtures/paramspec.pyi | 3 +- 6 files changed, 157 insertions(+), 29 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 68ea7c30ed6f..8ea30e8454c9 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1927,7 +1927,7 @@ def infer_function_type_arguments( ) arg_pass_nums = self.get_arg_infer_passes( - callee_type.arg_types, formal_to_actual, len(args) + callee_type, arg_types, formal_to_actual, len(args) ) pass1_args: list[Type | None] = [] @@ -1941,6 +1941,7 @@ def infer_function_type_arguments( callee_type, pass1_args, arg_kinds, + arg_names, formal_to_actual, context=self.argument_infer_context(), strict=self.chk.in_checked_function(), @@ -2001,6 +2002,7 @@ def infer_function_type_arguments( callee_type, arg_types, arg_kinds, + arg_names, formal_to_actual, context=self.argument_infer_context(), strict=self.chk.in_checked_function(), @@ -2080,6 +2082,7 @@ def infer_function_type_arguments_pass2( callee_type, arg_types, arg_kinds, + arg_names, formal_to_actual, context=self.argument_infer_context(), ) @@ -2092,7 +2095,11 @@ def argument_infer_context(self) -> ArgumentInferContext: ) def get_arg_infer_passes( - self, arg_types: list[Type], formal_to_actual: list[list[int]], num_actuals: int + self, + callee: CallableType, + arg_types: list[Type], + formal_to_actual: list[list[int]], + num_actuals: int, ) -> list[int]: """Return pass numbers for args for two-pass argument type inference. @@ -2103,8 +2110,22 @@ def get_arg_infer_passes( lambdas more effectively. """ res = [1] * num_actuals - for i, arg in enumerate(arg_types): - if arg.accept(ArgInferSecondPassQuery()): + for i, arg in enumerate(callee.arg_types): + skip_param_spec = False + for j in formal_to_actual[i]: + p_actual = get_proper_type(arg_types[j]) + if isinstance(p_actual, CallableType) and not p_actual.variables: + # This is an exception from the usual logic where we put generic Callable + # arguments in the second pass. If we have a non-generic actual, it is + # likely to infer good constraints, for example if we have: + # def run(Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None: ... + # def test(x: int, y: int) -> int: ... + # run(test, 1, 2) + # we will use `test` for inference, since it will allow to infer also + # argument *names* for P <: [x: int, y: int]. + skip_param_spec = True + break + if arg.accept(ArgInferSecondPassQuery(skip_param_spec=skip_param_spec)): for j in formal_to_actual[i]: res[j] = 2 return res @@ -4832,7 +4853,7 @@ def infer_lambda_type_using_context( self.chk.fail(message_registry.CANNOT_INFER_LAMBDA_TYPE, e) return None, None - return callable_ctx, callable_ctx + return callable_ctx.copy_modified(arg_names=e.arg_names), callable_ctx def visit_super_expr(self, e: SuperExpr) -> Type: """Type check a super expression (non-lvalue).""" @@ -5846,22 +5867,30 @@ class ArgInferSecondPassQuery(types.BoolTypeQuery): a type variable. """ - def __init__(self) -> None: + def __init__(self, skip_param_spec: bool) -> None: super().__init__(types.ANY_STRATEGY) + self.skip_param_spec = skip_param_spec def visit_callable_type(self, t: CallableType) -> bool: - return self.query_types(t.arg_types) or t.accept(HasTypeVarQuery()) + # TODO: we need to check only for type variables of original callable. + return self.query_types(t.arg_types) or t.accept( + HasTypeVarQuery(skip_param_spec=self.skip_param_spec) + ) class HasTypeVarQuery(types.BoolTypeQuery): """Visitor for querying whether a type has a type variable component.""" - def __init__(self) -> None: + def __init__(self, skip_param_spec: bool) -> None: super().__init__(types.ANY_STRATEGY) + self.skip_param_spec = skip_param_spec def visit_type_var(self, t: TypeVarType) -> bool: return True + def visit_param_spec(self, t: ParamSpecType) -> bool: + return not self.skip_param_spec + def has_erased_component(t: Type | None) -> bool: return t is not None and t.accept(HasErasedComponentsQuery()) diff --git a/mypy/constraints.py b/mypy/constraints.py index 04c3378ce16b..fa02e2f74301 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -96,6 +96,7 @@ def infer_constraints_for_callable( callee: CallableType, arg_types: Sequence[Type | None], arg_kinds: list[ArgKind], + arg_names: Sequence[str | None] | None, formal_to_actual: list[list[int]], context: ArgumentInferContext, ) -> list[Constraint]: @@ -106,6 +107,20 @@ def infer_constraints_for_callable( constraints: list[Constraint] = [] mapper = ArgTypeExpander(context) + param_spec = callee.param_spec() + param_spec_arg_types = [] + param_spec_arg_names = [] + param_spec_arg_kinds = [] + + incomplete_star_mapping = False + for i, actuals in enumerate(formal_to_actual): + for actual in actuals: + if actual is None and callee.arg_kinds[i] in (ARG_STAR, ARG_STAR2): + # We can't use arguments to infer ParamSpec constraint, if only some + # are present in the current inference pass. + incomplete_star_mapping = True + break + for i, actuals in enumerate(formal_to_actual): if isinstance(callee.arg_types[i], UnpackType): unpack_type = callee.arg_types[i] @@ -165,11 +180,47 @@ def infer_constraints_for_callable( actual_type = mapper.expand_actual_type( actual_arg_type, arg_kinds[actual], callee.arg_names[i], callee.arg_kinds[i] ) - # TODO: if callee has ParamSpec, we need to collect all actuals that map to star - # args and create single constraint between P and resulting Parameters instead. - c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF) - constraints.extend(c) - + if ( + param_spec + and callee.arg_kinds[i] in (ARG_STAR, ARG_STAR2) + and not incomplete_star_mapping + ): + # If actual arguments are mapped to ParamSpec type, we can't infer individual + # constraints, instead store them and infer single constraint at the end. + # It is impossible to map actual kind to formal kind, so use some heuristic. + # This inference is used as a fallback, so relying on heuristic should be OK. + param_spec_arg_types.append( + mapper.expand_actual_type( + actual_arg_type, arg_kinds[actual], None, arg_kinds[actual] + ) + ) + param_spec_arg_kinds.append( + ARG_POS + if arg_kinds[actual] not in (ARG_STAR, ARG_STAR2) + else arg_kinds[actual] + ) + param_spec_arg_names.append(arg_names[actual] if arg_names else None) + else: + c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF) + constraints.extend(c) + if ( + param_spec + and not any(c.type_var == param_spec.id for c in constraints) + and not incomplete_star_mapping + ): + # Use ParamSpec constraint from arguments only if there are no other constraints, + # since as explained above it is quite ad-hoc. + constraints.append( + Constraint( + param_spec, + SUPERTYPE_OF, + Parameters( + arg_types=param_spec_arg_types, + arg_kinds=param_spec_arg_kinds, + arg_names=param_spec_arg_names, + ), + ) + ) return constraints diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 0e98ed048197..686b1a318330 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -383,8 +383,6 @@ def visit_callable_type(self, t: CallableType) -> CallableType: t = t.expand_param_spec(repl) return t.copy_modified( arg_types=self.expand_types(t.arg_types), - arg_kinds=t.arg_kinds, - arg_names=t.arg_names, ret_type=t.ret_type.accept(self), type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None), ) diff --git a/mypy/infer.py b/mypy/infer.py index f34087910e4b..ba4a1d2bc9b1 100644 --- a/mypy/infer.py +++ b/mypy/infer.py @@ -33,6 +33,7 @@ def infer_function_type_arguments( callee_type: CallableType, arg_types: Sequence[Type | None], arg_kinds: list[ArgKind], + arg_names: Sequence[str | None] | None, formal_to_actual: list[list[int]], context: ArgumentInferContext, strict: bool = True, @@ -53,7 +54,7 @@ def infer_function_type_arguments( """ # Infer constraints. constraints = infer_constraints_for_callable( - callee_type, arg_types, arg_kinds, formal_to_actual, context + callee_type, arg_types, arg_kinds, arg_names, formal_to_actual, context ) # Solve constraints. diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index f523cb005a2c..9a4f67af2a0e 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -160,9 +160,8 @@ from typing import Callable, TypeVar from typing_extensions import ParamSpec P = ParamSpec('P') -R = TypeVar('R') -def f(x: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: +def f(x: Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None: return x(*args, **kwargs) def g(x: int, y: str) -> None: ... @@ -171,7 +170,24 @@ reveal_type(f(g, 1, y='x')) # N: Revealed type is "None" f(g, 'x', y='x') # E: Argument 2 to "f" has incompatible type "str"; expected "int" f(g, 1, y=1) # E: Argument "y" to "f" has incompatible type "int"; expected "str" f(g) # E: Missing positional arguments "x", "y" in call to "f" +[builtins fixtures/dict.pyi] + +[case testParamSpecFunctionGeneric] +from typing import Callable, TypeVar +from typing_extensions import ParamSpec + +P = ParamSpec('P') +R = TypeVar('R') +def f(x: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: + return x(*args, **kwargs) + +def g(x: int, y: str) -> None: ... + +reveal_type(f(g, 1, y='x')) # N: Revealed type is "None" +f(g, 'x', y='x') # E: Argument 1 to "f" has incompatible type "Callable[[int, str], None]"; expected "Callable[[str, str], None]" +f(g, 1, y=1) # E: Argument 1 to "f" has incompatible type "Callable[[int, str], None]"; expected "Callable[[int, int], None]" +f(g) # E: Argument 1 to "f" has incompatible type "Callable[[int, str], None]"; expected "Callable[[], None]" [builtins fixtures/dict.pyi] [case testParamSpecSpecialCase] @@ -347,14 +363,15 @@ P = ParamSpec('P') T = TypeVar('T') # Similar to atexit.register -def register(f: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> Callable[P, T]: ... # N: "register" defined here +def register(f: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> Callable[P, T]: ... def f(x: int) -> None: pass reveal_type(register(lambda: f(1))) # N: Revealed type is "def ()" -reveal_type(register(lambda x: f(x), x=1)) # N: Revealed type is "def (x: Any)" -register(lambda x: f(x)) # E: Missing positional argument "x" in call to "register" -register(lambda x: f(x), y=1) # E: Unexpected keyword argument "y" for "register" +reveal_type(register(lambda x: f(x), x=1)) # N: Revealed type is "def (x: Literal[1]?)" +register(lambda x: f(x)) # E: Cannot infer type of lambda \ + # E: Argument 1 to "register" has incompatible type "Callable[[Any], None]"; expected "Callable[[], None]" +register(lambda x: f(x), y=1) # E: Argument 1 to "register" has incompatible type "Callable[[Arg(int, 'x')], None]"; expected "Callable[[Arg(int, 'y')], None]" [builtins fixtures/dict.pyi] [case testParamSpecInvalidCalls] @@ -841,8 +858,7 @@ def f(x: int) -> int: reveal_type(A().func(f, 42)) # N: Revealed type is "builtins.int" -# TODO: this should reveal `int` -reveal_type(A().func(lambda x: x + x, 42)) # N: Revealed type is "Any" +reveal_type(A().func(lambda x: x + x, 42)) # N: Revealed type is "builtins.int" [builtins fixtures/paramspec.pyi] [case testParamSpecConstraintOnOtherParamSpec] @@ -1287,7 +1303,6 @@ P = ParamSpec('P') class Some(Generic[P]): def call(self, *args: P.args, **kwargs: P.kwargs): ... -# TODO: this probably should be reported. def call(*args: P.args, **kwargs: P.kwargs): ... [builtins fixtures/paramspec.pyi] @@ -1564,7 +1579,41 @@ dec(test_with_bound)(0) # E: Value of type variable "T" of function cannot be " dec(test_with_bound)(A()) # OK [builtins fixtures/paramspec.pyi] +[case testParamSpecArgumentParamInferenceRegular] +from typing import TypeVar, Generic +from typing_extensions import ParamSpec + +P = ParamSpec("P") +class Foo(Generic[P]): + def call(self, *args: P.args, **kwargs: P.kwargs) -> None: ... +def test(*args: P.args, **kwargs: P.kwargs) -> Foo[P]: ... + +reveal_type(test(1, 2)) # N: Revealed type is "__main__.Foo[[Literal[1]?, Literal[2]?]]" +reveal_type(test(x=1, y=2)) # N: Revealed type is "__main__.Foo[[x: Literal[1]?, y: Literal[2]?]]" +ints = [1, 2, 3] +reveal_type(test(*ints)) # N: Revealed type is "__main__.Foo[[*builtins.int]]" +[builtins fixtures/paramspec.pyi] + +[case testParamSpecArgumentParamInferenceGeneric] +# flags: --new-type-inference +from typing import Callable, TypeVar +from typing_extensions import ParamSpec + +P = ParamSpec("P") +R = TypeVar("R") +def call(f: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: + return f(*args, **kwargs) + +T = TypeVar("T") +def identity(x: T) -> T: + return x + +reveal_type(call(identity, 2)) # N: Revealed type is "builtins.int" +y: int = call(identity, 2) +[builtins fixtures/paramspec.pyi] + [case testParamSpecNestedApplyNoCrash] +# flags: --new-type-inference from typing import Callable, TypeVar from typing_extensions import ParamSpec @@ -1572,7 +1621,6 @@ P = ParamSpec("P") T = TypeVar("T") def apply(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: ... -def test() -> None: ... -# TODO: avoid this error, although it may be non-trivial. -apply(apply, test) # E: Argument 2 to "apply" has incompatible type "Callable[[], None]"; expected "Callable[P, T]" +def test() -> int: ... +reveal_type(apply(apply, test)) # N: Revealed type is "builtins.int" [builtins fixtures/paramspec.pyi] diff --git a/test-data/unit/fixtures/paramspec.pyi b/test-data/unit/fixtures/paramspec.pyi index 5e4b8564e238..9b0089f6a7e9 100644 --- a/test-data/unit/fixtures/paramspec.pyi +++ b/test-data/unit/fixtures/paramspec.pyi @@ -30,7 +30,8 @@ class list(Sequence[T], Generic[T]): def __iter__(self) -> Iterator[T]: ... class int: - def __neg__(self) -> 'int': ... + def __neg__(self) -> int: ... + def __add__(self, other: int) -> int: ... class bool(int): ... class float: ... From dacbb5fb3842c470683f44d7a6ba95f6f5d02f41 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Thu, 17 Aug 2023 11:03:56 +0100 Subject: [PATCH 2/6] Handle corner case; update some tests --- mypy/checkexpr.py | 2 ++ mypy/expandtype.py | 2 ++ mypy/solve.py | 22 ++++++++++++----- .../unit/check-parameter-specification.test | 13 ++++++++++ test-data/unit/typexport-basic.test | 24 +++++++++---------- 5 files changed, 45 insertions(+), 18 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 8ea30e8454c9..bb6674ec52ff 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -4853,6 +4853,8 @@ def infer_lambda_type_using_context( self.chk.fail(message_registry.CANNOT_INFER_LAMBDA_TYPE, e) return None, None + # Type of lambda must have correct argument names, to prevent false + # negatives when lambdas appear in `ParamSpec` context. return callable_ctx.copy_modified(arg_names=e.arg_names), callable_ctx def visit_super_expr(self, e: SuperExpr) -> Type: diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 686b1a318330..0e98ed048197 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -383,6 +383,8 @@ def visit_callable_type(self, t: CallableType) -> CallableType: t = t.expand_param_spec(repl) return t.copy_modified( arg_types=self.expand_types(t.arg_types), + arg_kinds=t.arg_kinds, + arg_names=t.arg_names, ret_type=t.ret_type.accept(self), type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None), ) diff --git a/mypy/solve.py b/mypy/solve.py index 4b2b899c2a8d..919f2bd7f56c 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -85,7 +85,7 @@ def solve_constraints( continue lowers = [c.target for c in cs if c.op == SUPERTYPE_OF] uppers = [c.target for c in cs if c.op == SUBTYPE_OF] - solution = solve_one(lowers, uppers) + solution = solve_one(originals[tv], lowers, uppers) # Do not leak type variables in non-polymorphic solutions. if solution is None or not get_vars( @@ -163,13 +163,17 @@ def solve_with_dependent( solutions: dict[TypeVarId, Type | None] = {} for flat_batch in batches: - res = solve_iteratively(flat_batch, graph, lowers, uppers) + res = solve_iteratively(flat_batch, graph, lowers, uppers, originals) solutions.update(res) return solutions, [free_solutions[tv] for tv in free_vars] def solve_iteratively( - batch: list[TypeVarId], graph: Graph, lowers: Bounds, uppers: Bounds + batch: list[TypeVarId], + graph: Graph, + lowers: Bounds, + uppers: Bounds, + originals: dict[TypeVarId, TypeVarLikeType], ) -> Solutions: """Solve transitive closure sequentially, updating upper/lower bounds after each step. @@ -195,7 +199,7 @@ def solve_iteratively( break # Solve each solvable type variable separately. s_batch.remove(solvable_tv) - result = solve_one(lowers[solvable_tv], uppers[solvable_tv]) + result = solve_one(originals[solvable_tv], lowers[solvable_tv], uppers[solvable_tv]) solutions[solvable_tv] = result if result is None: # TODO: support backtracking lower/upper bound choices and order within SCCs. @@ -228,7 +232,9 @@ def solve_iteratively( return solutions -def solve_one(lowers: Iterable[Type], uppers: Iterable[Type]) -> Type | None: +def solve_one( + type_var: TypeVarLikeType, lowers: Iterable[Type], uppers: Iterable[Type] +) -> Type | None: """Solve constraints by finding by using meets of upper bounds, and joins of lower bounds.""" bottom: Type | None = None top: Type | None = None @@ -268,7 +274,11 @@ def solve_one(lowers: Iterable[Type], uppers: Iterable[Type]) -> Type | None: return None elif top is None: candidate = bottom - elif is_subtype(bottom, top): + elif is_subtype(bottom, top, ignore_pos_arg_names=isinstance(type_var, ParamSpecType)): + # We ignore positional argument names to handle rare but important corner case, when + # constraints from both callable inference (more precise) and from ParamSpec vs arguments + # (more ad-hoc) are both present (e.g. if they were inferred at different nesting levels). + # In such case we conservatively solve [int] <: P <: [x: int] as P = [int]. candidate = bottom else: candidate = None diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index 9a4f67af2a0e..d3273d6fb7da 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -1624,3 +1624,16 @@ def apply(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: ... def test() -> int: ... reveal_type(apply(apply, test)) # N: Revealed type is "builtins.int" [builtins fixtures/paramspec.pyi] + +[case testParamSpecNestedApplyPosVsNamed] +from typing import Callable, TypeVar +from typing_extensions import ParamSpec + +P = ParamSpec("P") +T = TypeVar("T") + +def apply(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> None: ... +def test(x: int) -> int: ... +apply(apply, test, x=42) # OK +apply(apply, test, 42) # Also OK (but requires some special casing) +[builtins fixtures/paramspec.pyi] diff --git a/test-data/unit/typexport-basic.test b/test-data/unit/typexport-basic.test index cd2afe2c1c75..c4c3a1d36f83 100644 --- a/test-data/unit/typexport-basic.test +++ b/test-data/unit/typexport-basic.test @@ -727,7 +727,7 @@ class A: pass class B: a = None # type: A [out] -LambdaExpr(2) : def (B) -> A +LambdaExpr(2) : def (x: B) -> A MemberExpr(2) : A NameExpr(2) : B @@ -756,7 +756,7 @@ class B: a = None # type: A [builtins fixtures/list.pyi] [out] -LambdaExpr(2) : def (B) -> builtins.list[A] +LambdaExpr(2) : def (x: B) -> builtins.list[A] ListExpr(2) : builtins.list[A] [case testLambdaAndHigherOrderFunction] @@ -775,7 +775,7 @@ map( CallExpr(9) : builtins.list[B] NameExpr(9) : def (f: def (A) -> B, a: builtins.list[A]) -> builtins.list[B] CallExpr(10) : B -LambdaExpr(10) : def (A) -> B +LambdaExpr(10) : def (x: A) -> B NameExpr(10) : def (a: A) -> B NameExpr(10) : builtins.list[A] NameExpr(10) : A @@ -795,7 +795,7 @@ map( [builtins fixtures/list.pyi] [out] NameExpr(10) : def (f: def (A) -> builtins.list[B], a: builtins.list[A]) -> builtins.list[B] -LambdaExpr(11) : def (A) -> builtins.list[B] +LambdaExpr(11) : def (x: A) -> builtins.list[B] ListExpr(11) : builtins.list[B] NameExpr(11) : def (a: A) -> B NameExpr(11) : builtins.list[A] @@ -817,7 +817,7 @@ map( -- context. Perhaps just fail instead? CallExpr(7) : builtins.list[Any] NameExpr(7) : def (f: builtins.list[def (A) -> Any], a: builtins.list[A]) -> builtins.list[Any] -LambdaExpr(8) : def (A) -> A +LambdaExpr(8) : def (x: A) -> A ListExpr(8) : builtins.list[def (A) -> Any] NameExpr(8) : A NameExpr(9) : builtins.list[A] @@ -838,7 +838,7 @@ map( [out] CallExpr(9) : builtins.list[B] NameExpr(9) : def (f: def (A) -> B, a: builtins.list[A]) -> builtins.list[B] -LambdaExpr(10) : def (A) -> B +LambdaExpr(10) : def (x: A) -> B MemberExpr(10) : B NameExpr(10) : A NameExpr(11) : builtins.list[A] @@ -860,7 +860,7 @@ map( CallExpr(9) : builtins.list[B] NameExpr(9) : def (f: def (A) -> B, a: builtins.list[A]) -> builtins.list[B] NameExpr(10) : builtins.list[A] -LambdaExpr(11) : def (A) -> B +LambdaExpr(11) : def (x: A) -> B MemberExpr(11) : B NameExpr(11) : A @@ -1212,7 +1212,7 @@ f( [builtins fixtures/list.pyi] [out] NameExpr(8) : Overload(def (x: builtins.int, f: def (builtins.int) -> builtins.int), def (x: builtins.str, f: def (builtins.str) -> builtins.str)) -LambdaExpr(9) : def (builtins.int) -> builtins.int +LambdaExpr(9) : def (x: builtins.int) -> builtins.int NameExpr(9) : builtins.int [case testExportOverloadArgTypeNested] @@ -1231,10 +1231,10 @@ f( lambda x: x) [builtins fixtures/list.pyi] [out] -LambdaExpr(9) : def (builtins.int) -> builtins.int -LambdaExpr(10) : def (builtins.int) -> builtins.int -LambdaExpr(12) : def (builtins.str) -> builtins.str -LambdaExpr(13) : def (builtins.str) -> builtins.str +LambdaExpr(9) : def (y: builtins.int) -> builtins.int +LambdaExpr(10) : def (x: builtins.int) -> builtins.int +LambdaExpr(12) : def (y: builtins.str) -> builtins.str +LambdaExpr(13) : def (x: builtins.str) -> builtins.str -- TODO -- From 2485aa59728c515bea36a7e2051511ada9e39718 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Thu, 17 Aug 2023 14:31:49 +0100 Subject: [PATCH 3/6] Try better special-casing --- mypy/checkexpr.py | 33 +++++---- mypy/constraints.py | 71 +++++++++++-------- mypy/expandtype.py | 2 + mypy/solve.py | 6 +- mypy/subtypes.py | 4 ++ mypy/types.py | 22 ++++++ .../unit/check-parameter-specification.test | 33 ++++----- 7 files changed, 103 insertions(+), 68 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index bb6674ec52ff..ec1143e32fcf 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1927,7 +1927,7 @@ def infer_function_type_arguments( ) arg_pass_nums = self.get_arg_infer_passes( - callee_type, arg_types, formal_to_actual, len(args) + callee_type, args, arg_types, formal_to_actual, len(args) ) pass1_args: list[Type | None] = [] @@ -2097,6 +2097,7 @@ def argument_infer_context(self) -> ArgumentInferContext: def get_arg_infer_passes( self, callee: CallableType, + args: list[Expression], arg_types: list[Type], formal_to_actual: list[list[int]], num_actuals: int, @@ -2112,9 +2113,10 @@ def get_arg_infer_passes( res = [1] * num_actuals for i, arg in enumerate(callee.arg_types): skip_param_spec = False - for j in formal_to_actual[i]: - p_actual = get_proper_type(arg_types[j]) - if isinstance(p_actual, CallableType) and not p_actual.variables: + p_formal = get_proper_type(callee.arg_types[i]) + if isinstance(p_formal, CallableType) and p_formal.param_spec(): + for j in formal_to_actual[i]: + p_actual = get_proper_type(arg_types[j]) # This is an exception from the usual logic where we put generic Callable # arguments in the second pass. If we have a non-generic actual, it is # likely to infer good constraints, for example if we have: @@ -2123,9 +2125,14 @@ def get_arg_infer_passes( # run(test, 1, 2) # we will use `test` for inference, since it will allow to infer also # argument *names* for P <: [x: int, y: int]. - skip_param_spec = True - break - if arg.accept(ArgInferSecondPassQuery(skip_param_spec=skip_param_spec)): + if ( + isinstance(p_actual, CallableType) + and not p_actual.variables + and not isinstance(args[j], LambdaExpr) + ): + skip_param_spec = True + break + if not skip_param_spec and arg.accept(ArgInferSecondPassQuery()): for j in formal_to_actual[i]: res[j] = 2 return res @@ -5869,29 +5876,25 @@ class ArgInferSecondPassQuery(types.BoolTypeQuery): a type variable. """ - def __init__(self, skip_param_spec: bool) -> None: + def __init__(self) -> None: super().__init__(types.ANY_STRATEGY) - self.skip_param_spec = skip_param_spec def visit_callable_type(self, t: CallableType) -> bool: # TODO: we need to check only for type variables of original callable. - return self.query_types(t.arg_types) or t.accept( - HasTypeVarQuery(skip_param_spec=self.skip_param_spec) - ) + return self.query_types(t.arg_types) or t.accept(HasTypeVarQuery()) class HasTypeVarQuery(types.BoolTypeQuery): """Visitor for querying whether a type has a type variable component.""" - def __init__(self, skip_param_spec: bool) -> None: + def __init__(self) -> None: super().__init__(types.ANY_STRATEGY) - self.skip_param_spec = skip_param_spec def visit_type_var(self, t: TypeVarType) -> bool: return True def visit_param_spec(self, t: ParamSpecType) -> bool: - return not self.skip_param_spec + return True def has_erased_component(t: Type | None) -> bool: diff --git a/mypy/constraints.py b/mypy/constraints.py index fa02e2f74301..238b6cb36d04 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -194,10 +194,9 @@ def infer_constraints_for_callable( actual_arg_type, arg_kinds[actual], None, arg_kinds[actual] ) ) + actual_kind = arg_kinds[actual] param_spec_arg_kinds.append( - ARG_POS - if arg_kinds[actual] not in (ARG_STAR, ARG_STAR2) - else arg_kinds[actual] + ARG_POS if actual_kind not in (ARG_STAR, ARG_STAR2) else actual_kind ) param_spec_arg_names.append(arg_names[actual] if arg_names else None) else: @@ -218,6 +217,7 @@ def infer_constraints_for_callable( arg_types=param_spec_arg_types, arg_kinds=param_spec_arg_kinds, arg_names=param_spec_arg_names, + imprecise_arg_kinds=True, ), ) ) @@ -1031,46 +1031,59 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: ) extra_tvars = True + # Compare prefixes as well + cactual_prefix = cactual.copy_modified( + arg_types=cactual.arg_types[:prefix_len], + arg_kinds=cactual.arg_kinds[:prefix_len], + arg_names=cactual.arg_names[:prefix_len], + ) + + for t, a in zip(prefix.arg_types, cactual_prefix.arg_types): + if isinstance(a, ParamSpecType): + continue + res.extend(infer_constraints(t, a, neg_op(self.direction))) + if not cactual_ps: max_prefix_len = len([k for k in cactual.arg_kinds if k in (ARG_POS, ARG_OPT)]) prefix_len = min(prefix_len, max_prefix_len) - res.append( - Constraint( - param_spec, - neg_op(self.direction), - Parameters( - arg_types=cactual.arg_types[prefix_len:], - arg_kinds=cactual.arg_kinds[prefix_len:], - arg_names=cactual.arg_names[prefix_len:], - variables=cactual.variables - if not type_state.infer_polymorphic - else [], - ), + # This logic matches top-level callable constraint exception, if we managed + # to get other constraints for ParamSpec, don't infer one with imprecise kinds + if not ( + any(c.type_var == param_spec.id for c in res) + and cactual.imprecise_arg_kinds + ): + res.append( + Constraint( + param_spec, + neg_op(self.direction), + Parameters( + arg_types=cactual.arg_types[prefix_len:], + arg_kinds=cactual.arg_kinds[prefix_len:], + arg_names=cactual.arg_names[prefix_len:], + variables=cactual.variables + if not type_state.infer_polymorphic + else [], + imprecise_arg_kinds=cactual.imprecise_arg_kinds, + ), + ) ) - ) else: - if len(param_spec.prefix.arg_types) <= len(cactual_ps.prefix.arg_types): + if len(param_spec.prefix.arg_types) <= len( + cactual_ps.prefix.arg_types + ) and not ( + any(c.type_var == param_spec.id for c in res) + and cactual_ps.prefix.imprecise_arg_kinds + ): cactual_ps = cactual_ps.copy_modified( prefix=Parameters( arg_types=cactual_ps.prefix.arg_types[prefix_len:], arg_kinds=cactual_ps.prefix.arg_kinds[prefix_len:], arg_names=cactual_ps.prefix.arg_names[prefix_len:], + imprecise_arg_kinds=cactual_ps.prefix.imprecise_arg_kinds, ) ) res.append(Constraint(param_spec, neg_op(self.direction), cactual_ps)) - # Compare prefixes as well - cactual_prefix = cactual.copy_modified( - arg_types=cactual.arg_types[:prefix_len], - arg_kinds=cactual.arg_kinds[:prefix_len], - arg_names=cactual.arg_names[:prefix_len], - ) - - for t, a in zip(prefix.arg_types, cactual_prefix.arg_types): - if isinstance(a, ParamSpecType): - continue - res.extend(infer_constraints(t, a, neg_op(self.direction))) - template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type if template.type_guard is not None: template_ret_type = template.type_guard diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 0e98ed048197..3bfa06f15ea2 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -387,6 +387,7 @@ def visit_callable_type(self, t: CallableType) -> CallableType: arg_names=t.arg_names, ret_type=t.ret_type.accept(self), type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None), + imprecise_arg_kinds=(t.imprecise_arg_kinds or repl.imprecise_arg_kinds), ) elif isinstance(repl, ParamSpecType): # We're substituting one ParamSpec for another; this can mean that the prefix @@ -402,6 +403,7 @@ def visit_callable_type(self, t: CallableType) -> CallableType: arg_kinds=t.arg_kinds[:-2] + prefix.arg_kinds + t.arg_kinds[-2:], arg_names=t.arg_names[:-2] + prefix.arg_names + t.arg_names[-2:], ret_type=t.ret_type.accept(self), + imprecise_arg_kinds=(t.imprecise_arg_kinds or prefix.imprecise_arg_kinds), ) var_arg = t.var_arg() diff --git a/mypy/solve.py b/mypy/solve.py index 919f2bd7f56c..42893a31a314 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -274,11 +274,7 @@ def solve_one( return None elif top is None: candidate = bottom - elif is_subtype(bottom, top, ignore_pos_arg_names=isinstance(type_var, ParamSpecType)): - # We ignore positional argument names to handle rare but important corner case, when - # constraints from both callable inference (more precise) and from ParamSpec vs arguments - # (more ad-hoc) are both present (e.g. if they were inferred at different nesting levels). - # In such case we conservatively solve [int] <: P <: [x: int] as P = [int]. + elif is_subtype(bottom, top): candidate = bottom else: candidate = None diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 60fccc7e357c..abfa15913203 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1485,6 +1485,10 @@ def are_parameters_compatible( if right.is_ellipsis_args: return True + if right.imprecise_arg_kinds: + allow_partial_overlap = True + ignore_pos_arg_names = True + left_star = left.var_arg() left_star2 = left.kw_arg() right_star = right.var_arg() diff --git a/mypy/types.py b/mypy/types.py index 359ca713616b..c8df9a6493a8 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -1559,6 +1559,7 @@ class Parameters(ProperType): "min_args", "is_ellipsis_args", "variables", + "imprecise_arg_kinds", ) def __init__( @@ -1569,6 +1570,7 @@ def __init__( *, variables: Sequence[TypeVarLikeType] | None = None, is_ellipsis_args: bool = False, + imprecise_arg_kinds: bool = False, line: int = -1, column: int = -1, ) -> None: @@ -1581,6 +1583,7 @@ def __init__( self.min_args = arg_kinds.count(ARG_POS) self.is_ellipsis_args = is_ellipsis_args self.variables = variables or [] + self.imprecise_arg_kinds = imprecise_arg_kinds def copy_modified( self, @@ -1590,6 +1593,7 @@ def copy_modified( *, variables: Bogus[Sequence[TypeVarLikeType]] = _dummy, is_ellipsis_args: Bogus[bool] = _dummy, + imprecise_arg_kinds: Bogus[bool] = _dummy, ) -> Parameters: return Parameters( arg_types=arg_types if arg_types is not _dummy else self.arg_types, @@ -1599,6 +1603,11 @@ def copy_modified( is_ellipsis_args if is_ellipsis_args is not _dummy else self.is_ellipsis_args ), variables=variables if variables is not _dummy else self.variables, + imprecise_arg_kinds=( + imprecise_arg_kinds + if imprecise_arg_kinds is not _dummy + else self.imprecise_arg_kinds + ), ) # the following are copied from CallableType. Is there a way to decrease code duplication? @@ -1695,6 +1704,7 @@ def serialize(self) -> JsonDict: "arg_kinds": [int(x.value) for x in self.arg_kinds], "arg_names": self.arg_names, "variables": [tv.serialize() for tv in self.variables], + "imprecise_arg_kinds": self.imprecise_arg_kinds, } @classmethod @@ -1705,6 +1715,7 @@ def deserialize(cls, data: JsonDict) -> Parameters: [ArgKind(x) for x in data["arg_kinds"]], data["arg_names"], variables=[cast(TypeVarLikeType, deserialize_type(v)) for v in data["variables"]], + imprecise_arg_kinds=data["imprecise_arg_kinds"], ) def __hash__(self) -> int: @@ -1761,6 +1772,7 @@ class CallableType(FunctionLike): "type_guard", # T, if -> TypeGuard[T] (ret_type is bool in this case). "from_concatenate", # whether this callable is from a concatenate object # (this is used for error messages) + "imprecise_arg_kinds", "unpack_kwargs", # Was an Unpack[...] with **kwargs used to define this callable? ) @@ -1785,6 +1797,7 @@ def __init__( def_extras: dict[str, Any] | None = None, type_guard: Type | None = None, from_concatenate: bool = False, + imprecise_arg_kinds: bool = False, unpack_kwargs: bool = False, ) -> None: super().__init__(line, column) @@ -1811,6 +1824,7 @@ def __init__( self.special_sig = special_sig self.from_type_type = from_type_type self.from_concatenate = from_concatenate + self.imprecise_arg_kinds = imprecise_arg_kinds if not bound_args: bound_args = () self.bound_args = bound_args @@ -1853,6 +1867,7 @@ def copy_modified( def_extras: Bogus[dict[str, Any]] = _dummy, type_guard: Bogus[Type | None] = _dummy, from_concatenate: Bogus[bool] = _dummy, + imprecise_arg_kinds: Bogus[bool] = _dummy, unpack_kwargs: Bogus[bool] = _dummy, ) -> CT: modified = CallableType( @@ -1878,6 +1893,11 @@ def copy_modified( from_concatenate=( from_concatenate if from_concatenate is not _dummy else self.from_concatenate ), + imprecise_arg_kinds=( + imprecise_arg_kinds + if imprecise_arg_kinds is not _dummy + else self.imprecise_arg_kinds + ), unpack_kwargs=unpack_kwargs if unpack_kwargs is not _dummy else self.unpack_kwargs, ) # Optimization: Only NewTypes are supported as subtypes since @@ -2129,6 +2149,7 @@ def serialize(self) -> JsonDict: "def_extras": dict(self.def_extras), "type_guard": self.type_guard.serialize() if self.type_guard is not None else None, "from_concatenate": self.from_concatenate, + "imprecise_arg_kinds": self.imprecise_arg_kinds, "unpack_kwargs": self.unpack_kwargs, } @@ -2152,6 +2173,7 @@ def deserialize(cls, data: JsonDict) -> CallableType: deserialize_type(data["type_guard"]) if data["type_guard"] is not None else None ), from_concatenate=data["from_concatenate"], + imprecise_arg_kinds=data["imprecise_arg_kinds"], unpack_kwargs=data["unpack_kwargs"], ) diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index d3273d6fb7da..fd1e1f0809ac 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -160,8 +160,9 @@ from typing import Callable, TypeVar from typing_extensions import ParamSpec P = ParamSpec('P') +R = TypeVar('R') -def f(x: Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None: +def f(x: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: return x(*args, **kwargs) def g(x: int, y: str) -> None: ... @@ -172,24 +173,6 @@ f(g, 1, y=1) # E: Argument "y" to "f" has incompatible type "int"; expected "st f(g) # E: Missing positional arguments "x", "y" in call to "f" [builtins fixtures/dict.pyi] -[case testParamSpecFunctionGeneric] -from typing import Callable, TypeVar -from typing_extensions import ParamSpec - -P = ParamSpec('P') -R = TypeVar('R') - -def f(x: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: - return x(*args, **kwargs) - -def g(x: int, y: str) -> None: ... - -reveal_type(f(g, 1, y='x')) # N: Revealed type is "None" -f(g, 'x', y='x') # E: Argument 1 to "f" has incompatible type "Callable[[int, str], None]"; expected "Callable[[str, str], None]" -f(g, 1, y=1) # E: Argument 1 to "f" has incompatible type "Callable[[int, str], None]"; expected "Callable[[int, int], None]" -f(g) # E: Argument 1 to "f" has incompatible type "Callable[[int, str], None]"; expected "Callable[[], None]" -[builtins fixtures/dict.pyi] - [case testParamSpecSpecialCase] from typing import Callable, TypeVar from typing_extensions import ParamSpec @@ -1637,3 +1620,15 @@ def test(x: int) -> int: ... apply(apply, test, x=42) # OK apply(apply, test, 42) # Also OK (but requires some special casing) [builtins fixtures/paramspec.pyi] + +[case testParamSpecApplyPosVsNamedOptional] +from typing import Callable, TypeVar +from typing_extensions import ParamSpec + +P = ParamSpec("P") +T = TypeVar("T") + +def apply(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> None: ... +def test(x: str = ..., y: int = ...) -> int: ... +apply(test, y=42) # OK +[builtins fixtures/paramspec.pyi] From 53382a7153f9fc3a246248ce0bb0600b7ffe8ff5 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Thu, 17 Aug 2023 14:41:35 +0100 Subject: [PATCH 4/6] Undo not needed changes --- mypy/solve.py | 16 +++++----------- mypy/subtypes.py | 4 ---- 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/mypy/solve.py b/mypy/solve.py index 42893a31a314..4b2b899c2a8d 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -85,7 +85,7 @@ def solve_constraints( continue lowers = [c.target for c in cs if c.op == SUPERTYPE_OF] uppers = [c.target for c in cs if c.op == SUBTYPE_OF] - solution = solve_one(originals[tv], lowers, uppers) + solution = solve_one(lowers, uppers) # Do not leak type variables in non-polymorphic solutions. if solution is None or not get_vars( @@ -163,17 +163,13 @@ def solve_with_dependent( solutions: dict[TypeVarId, Type | None] = {} for flat_batch in batches: - res = solve_iteratively(flat_batch, graph, lowers, uppers, originals) + res = solve_iteratively(flat_batch, graph, lowers, uppers) solutions.update(res) return solutions, [free_solutions[tv] for tv in free_vars] def solve_iteratively( - batch: list[TypeVarId], - graph: Graph, - lowers: Bounds, - uppers: Bounds, - originals: dict[TypeVarId, TypeVarLikeType], + batch: list[TypeVarId], graph: Graph, lowers: Bounds, uppers: Bounds ) -> Solutions: """Solve transitive closure sequentially, updating upper/lower bounds after each step. @@ -199,7 +195,7 @@ def solve_iteratively( break # Solve each solvable type variable separately. s_batch.remove(solvable_tv) - result = solve_one(originals[solvable_tv], lowers[solvable_tv], uppers[solvable_tv]) + result = solve_one(lowers[solvable_tv], uppers[solvable_tv]) solutions[solvable_tv] = result if result is None: # TODO: support backtracking lower/upper bound choices and order within SCCs. @@ -232,9 +228,7 @@ def solve_iteratively( return solutions -def solve_one( - type_var: TypeVarLikeType, lowers: Iterable[Type], uppers: Iterable[Type] -) -> Type | None: +def solve_one(lowers: Iterable[Type], uppers: Iterable[Type]) -> Type | None: """Solve constraints by finding by using meets of upper bounds, and joins of lower bounds.""" bottom: Type | None = None top: Type | None = None diff --git a/mypy/subtypes.py b/mypy/subtypes.py index abfa15913203..60fccc7e357c 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1485,10 +1485,6 @@ def are_parameters_compatible( if right.is_ellipsis_args: return True - if right.imprecise_arg_kinds: - allow_partial_overlap = True - ignore_pos_arg_names = True - left_star = left.var_arg() left_star2 = left.kw_arg() right_star = right.var_arg() From dfaa4f9fc1902ef335a7e3ed6818f3d3f14e77a8 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Thu, 17 Aug 2023 15:31:02 +0100 Subject: [PATCH 5/6] Simplify logic; try including return constraints --- mypy/constraints.py | 60 ++++++++++++++++++++------------------------- 1 file changed, 27 insertions(+), 33 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index 238b6cb36d04..0c83ea320d54 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -963,6 +963,14 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: res: list[Constraint] = [] cactual = self.actual.with_unpacked_kwargs() param_spec = template.param_spec() + + template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type + if template.type_guard is not None: + template_ret_type = template.type_guard + if cactual.type_guard is not None: + cactual_ret_type = cactual.type_guard + res.extend(infer_constraints(template_ret_type, cactual_ret_type, self.direction)) + if param_spec is None: # TODO: Erase template variables if it is generic? if ( @@ -1043,38 +1051,31 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: continue res.extend(infer_constraints(t, a, neg_op(self.direction))) + param_spec_target: Type | None = None + skip_imprecise = ( + any(c.type_var == param_spec.id for c in res) and cactual.imprecise_arg_kinds + ) if not cactual_ps: max_prefix_len = len([k for k in cactual.arg_kinds if k in (ARG_POS, ARG_OPT)]) prefix_len = min(prefix_len, max_prefix_len) # This logic matches top-level callable constraint exception, if we managed # to get other constraints for ParamSpec, don't infer one with imprecise kinds - if not ( - any(c.type_var == param_spec.id for c in res) - and cactual.imprecise_arg_kinds - ): - res.append( - Constraint( - param_spec, - neg_op(self.direction), - Parameters( - arg_types=cactual.arg_types[prefix_len:], - arg_kinds=cactual.arg_kinds[prefix_len:], - arg_names=cactual.arg_names[prefix_len:], - variables=cactual.variables - if not type_state.infer_polymorphic - else [], - imprecise_arg_kinds=cactual.imprecise_arg_kinds, - ), - ) + if not skip_imprecise: + param_spec_target = Parameters( + arg_types=cactual.arg_types[prefix_len:], + arg_kinds=cactual.arg_kinds[prefix_len:], + arg_names=cactual.arg_names[prefix_len:], + variables=cactual.variables + if not type_state.infer_polymorphic + else [], + imprecise_arg_kinds=cactual.imprecise_arg_kinds, ) else: - if len(param_spec.prefix.arg_types) <= len( - cactual_ps.prefix.arg_types - ) and not ( - any(c.type_var == param_spec.id for c in res) - and cactual_ps.prefix.imprecise_arg_kinds + if ( + len(param_spec.prefix.arg_types) <= len(cactual_ps.prefix.arg_types) + and not skip_imprecise ): - cactual_ps = cactual_ps.copy_modified( + param_spec_target = cactual_ps.copy_modified( prefix=Parameters( arg_types=cactual_ps.prefix.arg_types[prefix_len:], arg_kinds=cactual_ps.prefix.arg_kinds[prefix_len:], @@ -1082,15 +1083,8 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: imprecise_arg_kinds=cactual_ps.prefix.imprecise_arg_kinds, ) ) - res.append(Constraint(param_spec, neg_op(self.direction), cactual_ps)) - - template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type - if template.type_guard is not None: - template_ret_type = template.type_guard - if cactual.type_guard is not None: - cactual_ret_type = cactual.type_guard - - res.extend(infer_constraints(template_ret_type, cactual_ret_type, self.direction)) + if param_spec_target is not None: + res.append(Constraint(param_spec, neg_op(self.direction), param_spec_target)) if extra_tvars: for c in res: c.extra_tvars += cactual.variables From 90906619e2fe70911b679b01ade0e49aef7da11f Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 25 Aug 2023 21:38:20 +0100 Subject: [PATCH 6/6] Add more tests --- test-data/unit/check-parameter-specification.test | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index b14244c7815f..6ef4a7e2463f 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -349,12 +349,16 @@ T = TypeVar('T') def register(f: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> Callable[P, T]: ... def f(x: int) -> None: pass +def g(x: int, y: str) -> None: pass reveal_type(register(lambda: f(1))) # N: Revealed type is "def ()" reveal_type(register(lambda x: f(x), x=1)) # N: Revealed type is "def (x: Literal[1]?)" register(lambda x: f(x)) # E: Cannot infer type of lambda \ # E: Argument 1 to "register" has incompatible type "Callable[[Any], None]"; expected "Callable[[], None]" register(lambda x: f(x), y=1) # E: Argument 1 to "register" has incompatible type "Callable[[Arg(int, 'x')], None]"; expected "Callable[[Arg(int, 'y')], None]" +reveal_type(register(lambda x: f(x), 1)) # N: Revealed type is "def (Literal[1]?)" +reveal_type(register(lambda x, y: g(x, y), 1, "a")) # N: Revealed type is "def (Literal[1]?, Literal['a']?)" +reveal_type(register(lambda x, y: g(x, y), 1, y="a")) # N: Revealed type is "def (Literal[1]?, y: Literal['a']?)" [builtins fixtures/dict.pyi] [case testParamSpecInvalidCalls]