Skip to content

Commit

Permalink
Fix explicit type for partial (#17424)
Browse files Browse the repository at this point in the history
Fixes #17301
  • Loading branch information
ilevkivskyi authored Jun 23, 2024
1 parent abdaf6a commit 1b116df
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 3 deletions.
17 changes: 14 additions & 3 deletions mypy/plugins/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import mypy.checker
import mypy.plugin
import mypy.semanal
from mypy.argmap import map_actuals_to_formals
from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, CallExpr, FuncItem, Var
from mypy.plugins.common import add_method_to_class
Expand All @@ -24,6 +25,8 @@

_ORDERING_METHODS: Final = {"__lt__", "__le__", "__gt__", "__ge__"}

PARTIAL = "functools.partial"


class _MethodInfo(NamedTuple):
is_static: bool
Expand Down Expand Up @@ -142,7 +145,8 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k)
)
for k in fn_type.arg_kinds
]
],
ret_type=ctx.api.named_generic_type(PARTIAL, [fn_type.ret_type]),
)
if defaulted.line < 0:
# Make up a line number if we don't have one
Expand Down Expand Up @@ -188,6 +192,13 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
bound = get_proper_type(bound)
if not isinstance(bound, CallableType):
return ctx.default_return_type
wrapped_ret_type = get_proper_type(bound.ret_type)
if not isinstance(wrapped_ret_type, Instance) or wrapped_ret_type.type.fullname != PARTIAL:
return ctx.default_return_type
if not mypy.semanal.refers_to_fullname(ctx.args[0][0], PARTIAL):
# If the first argument is partial, above call will trigger the plugin
# again, in between the wrapping above an unwrapping here.
bound = bound.copy_modified(ret_type=wrapped_ret_type.args[0])

formal_to_actual = map_actuals_to_formals(
actual_kinds=actual_arg_kinds,
Expand Down Expand Up @@ -237,7 +248,7 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
ret_type=ret_type,
)

ret = ctx.api.named_generic_type("functools.partial", [ret_type])
ret = ctx.api.named_generic_type(PARTIAL, [ret_type])
ret = ret.copy_with_extra_attr("__mypy_partial", partially_applied)
return ret

Expand All @@ -247,7 +258,7 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
if (
not isinstance(ctx.api, mypy.checker.TypeChecker) # use internals
or not isinstance(ctx.type, Instance)
or ctx.type.type.fullname != "functools.partial"
or ctx.type.type.fullname != PARTIAL
or not ctx.type.extra_attrs
or "__mypy_partial" not in ctx.type.extra_attrs.attrs
):
Expand Down
31 changes: 31 additions & 0 deletions test-data/unit/check-functools.test
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,37 @@ reveal_type(functools.partial(fn3, 2)()) # E: "str" not callable \
# E: Argument 1 to "partial" has incompatible type "Union[Callable[[int], int], str]"; expected "Callable[..., int]"
[builtins fixtures/tuple.pyi]

[case testFunctoolsPartialExplicitType]
from functools import partial
from typing import Type, TypeVar, Callable

T = TypeVar("T")
def generic(string: str, integer: int, resulting_type: Type[T]) -> T: ...

p: partial[str] = partial(generic, resulting_type=str)
q: partial[bool] = partial(generic, resulting_type=str) # E: Argument "resulting_type" to "generic" has incompatible type "Type[str]"; expected "Type[bool]"

pc: Callable[..., str] = partial(generic, resulting_type=str)
qc: Callable[..., bool] = partial(generic, resulting_type=str) # E: Incompatible types in assignment (expression has type "partial[str]", variable has type "Callable[..., bool]") \
# N: "partial[str].__call__" has type "Callable[[VarArg(Any), KwArg(Any)], str]"
[builtins fixtures/tuple.pyi]

[case testFunctoolsPartialNestedPartial]
from functools import partial
from typing import Any

def foo(x: int) -> int: ...
p = partial(partial, foo)
reveal_type(p()(1)) # N: Revealed type is "builtins.int"
p()("no") # E: Argument 1 to "foo" has incompatible type "str"; expected "int"

q = partial(partial, partial, foo)
q()()("no") # E: Argument 1 to "foo" has incompatible type "str"; expected "int"

r = partial(partial, foo, 1)
reveal_type(r()()) # N: Revealed type is "builtins.int"
[builtins fixtures/tuple.pyi]

[case testFunctoolsPartialTypeObject]
import functools
from typing import Type, Generic, TypeVar
Expand Down

0 comments on commit 1b116df

Please sign in to comment.