Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More helpful type guards #14238

Merged
merged 9 commits into from
Jan 31, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5284,10 +5284,26 @@ def find_isinstance_check_helper(self, node: Expression) -> tuple[TypeMap, TypeM
return self.hasattr_type_maps(expr, self.lookup_type(expr), attr[0])
elif isinstance(node.callee, RefExpr):
if node.callee.type_guard is not None:
# TODO: Follow keyword args or *args, **kwargs
# TODO: Follow *args, **kwargs
if node.arg_kinds[0] != nodes.ARG_POS:
self.fail(message_registry.TYPE_GUARD_POS_ARG_REQUIRED, node)
return {}, {}
# the first argument might be used as a kwarg
called_type = get_proper_type(self.lookup_type(node.callee))
assert isinstance(called_type, (CallableType, Overloaded))

# *assuming* the overloaded function is correct, there's a couple cases:
# 1) The first argument has different names, but is pos-only. We don't
# care about this case, the argument must be passed positionally.
# 2) The first argument allows keyword reference, therefore must be the
# same between overloads.
name = called_type.items[0].arg_names[0]

if name in node.arg_names:
idx = node.arg_names.index(name)
# we want the idx-th variable to be narrowed
expr = collapse_walrus(node.args[idx])
A5rocks marked this conversation as resolved.
Show resolved Hide resolved
else:
self.fail(message_registry.TYPE_GUARD_POS_ARG_REQUIRED, node)
return {}, {}
if literal(expr) == LITERAL_TYPE:
# Note: we wrap the target type, so that we can special case later.
# Namely, for isinstance() we use a normal meet, while TypeGuard is
Expand Down
14 changes: 14 additions & 0 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,20 @@ def analyze_func_def(self, defn: FuncDef) -> None:
return
assert isinstance(result, ProperType)
if isinstance(result, CallableType):
# type guards need to have a positional argument, to spec
A5rocks marked this conversation as resolved.
Show resolved Hide resolved
if (
result.type_guard
and ARG_POS not in result.arg_kinds[self.is_class_scope() :]
and not defn.is_static
):
self.fail(
"TypeGuard functions must have a positional argument",
result,
code=codes.VALID_TYPE,
)
# in this case, we just kind of just ... remove the type guard.
result = result.copy_modified(type_guard=None)

result = self.remove_unpack_kwargs(defn, result)
if has_self_type and self.type is not None:
info = self.type
Expand Down
28 changes: 28 additions & 0 deletions test-data/unit/check-python38.test
Original file line number Diff line number Diff line change
Expand Up @@ -734,3 +734,31 @@ class C(Generic[T]):
[out]
main:10: note: Revealed type is "builtins.int"
main:10: note: Revealed type is "builtins.str"

[case testTypeGuardWithPositionalOnlyArg]
# flags: --python-version 3.8
from typing_extensions import TypeGuard

def typeguard(x: object, /) -> TypeGuard[int]:
...

n: object
if typeguard(n):
reveal_type(n)
[builtins fixtures/tuple.pyi]
[out]
main:9: note: Revealed type is "builtins.int"

[case testTypeGuardKeywordFollowingWalrus]
# flags: --python-version 3.8
from typing import cast
from typing_extensions import TypeGuard

def typeguard(x: object) -> TypeGuard[int]:
...

if typeguard(x=(n := cast(object, "hi"))):
reveal_type(n)
[builtins fixtures/tuple.pyi]
[out]
main:9: note: Revealed type is "builtins.int"
88 changes: 81 additions & 7 deletions test-data/unit/check-typeguard.test
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ reveal_type(foo) # N: Revealed type is "def (a: builtins.object) -> TypeGuard[b
[case testTypeGuardCallArgsNone]
from typing_extensions import TypeGuard
class Point: pass
# TODO: error on the 'def' line (insufficient args for type guard)
def is_point() -> TypeGuard[Point]: pass

def is_point() -> TypeGuard[Point]: pass # E: TypeGuard functions must have a positional argument
def main(a: object) -> None:
if is_point():
reveal_type(a) # N: Revealed type is "builtins.object"
Expand Down Expand Up @@ -227,13 +227,13 @@ def main(a: object) -> None:
from typing_extensions import TypeGuard
def is_float(a: object, b: object = 0) -> TypeGuard[float]: pass
def main1(a: object) -> None:
# This is debatable -- should we support these cases?
if is_float(a=a, b=1):
reveal_type(a) # N: Revealed type is "builtins.float"

if is_float(a=a, b=1): # E: Type guard requires positional argument
reveal_type(a) # N: Revealed type is "builtins.object"
if is_float(b=1, a=a):
reveal_type(a) # N: Revealed type is "builtins.float"

if is_float(b=1, a=a): # E: Type guard requires positional argument
reveal_type(a) # N: Revealed type is "builtins.object"
# This is debatable -- should we support these cases?

ta = (a,)
if is_float(*ta): # E: Type guard requires positional argument
Expand Down Expand Up @@ -597,3 +597,77 @@ def func(names: Tuple[str, ...]):
if is_two_element_tuple(names):
reveal_type(names) # N: Revealed type is "Tuple[builtins.str, builtins.str]"
[builtins fixtures/tuple.pyi]

[case testTypeGuardErroneousDefinitionFails]
from typing_extensions import TypeGuard

class Z:
def typeguard(self, *, x: object) -> TypeGuard[int]: # E: TypeGuard functions must have a positional argument
...

def bad_typeguard(*, x: object) -> TypeGuard[int]: # E: TypeGuard functions must have a positional argument
...
[builtins fixtures/tuple.pyi]

[case testTypeGuardWithKeywordArg]
from typing_extensions import TypeGuard

class Z:
def typeguard(self, x: object) -> TypeGuard[int]:
...

def typeguard(x: object) -> TypeGuard[int]:
...

n: object
if typeguard(x=n):
reveal_type(n) # N: Revealed type is "builtins.int"

if Z().typeguard(x=n):
reveal_type(n) # N: Revealed type is "builtins.int"
[builtins fixtures/tuple.pyi]

[case testStaticMethodTypeGuard]
from typing_extensions import TypeGuard

class Y:
@staticmethod
def typeguard(h: object) -> TypeGuard[int]:
...

x: object
if Y().typeguard(x):
reveal_type(x) # N: Revealed type is "builtins.int"
if Y.typeguard(x):
reveal_type(x) # N: Revealed type is "builtins.int"
[builtins fixtures/tuple.pyi]
[builtins fixtures/classmethod.pyi]

[case testTypeGuardKwargFollowingThroughOverloaded]
from typing import overload, Union
from typing_extensions import TypeGuard

@overload
def typeguard(x: object, y: str) -> TypeGuard[str]:
...

@overload
def typeguard(x: object, y: int) -> TypeGuard[int]:
...

def typeguard(x: object, y: Union[int, str]) -> Union[TypeGuard[int], TypeGuard[str]]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure this return type makes sense. But can add an error in another PR

...

x: object
if typeguard(x=x, y=42):
reveal_type(x) # N: Revealed type is "builtins.int"

if typeguard(y=42, x=x):
reveal_type(x) # N: Revealed type is "builtins.int"

if typeguard(x=x, y="42"):
reveal_type(x) # N: Revealed type is "builtins.str"

if typeguard(y="42", x=x):
reveal_type(x) # N: Revealed type is "builtins.str"
[builtins fixtures/tuple.pyi]