Skip to content

Commit

Permalink
Improve for loop index variable type narrowing (#18014)
Browse files Browse the repository at this point in the history
Preserve the literal type of index expressions a bit longer (until the
next assignment) to support TypedDict lookups.

```py
from typing import TypedDict

class X(TypedDict):
    hourly: int
    daily: int

def func(x: X) -> None:
    for var in ("hourly", "daily"):
        print(x[var])
```

Closes #9230
  • Loading branch information
cdce8p authored Oct 23, 2024
1 parent 60d1b37 commit 9e68959
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 4 deletions.
9 changes: 9 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3175,6 +3175,14 @@ def check_assignment(
# Don't use type binder for definitions of special forms, like named tuples.
if not (isinstance(lvalue, NameExpr) and lvalue.is_special_form):
self.binder.assign_type(lvalue, rvalue_type, lvalue_type, False)
if (
isinstance(lvalue, NameExpr)
and isinstance(lvalue.node, Var)
and lvalue.node.is_inferred
and lvalue.node.is_index_var
and lvalue_type is not None
):
lvalue.node.type = remove_instance_last_known_values(lvalue_type)

elif index_lvalue:
self.check_indexed_assignment(index_lvalue, rvalue, lvalue)
Expand All @@ -3184,6 +3192,7 @@ def check_assignment(
rvalue_type = self.expr_checker.accept(rvalue, type_context=type_context)
if not (
inferred.is_final
or inferred.is_index_var
or (isinstance(lvalue, NameExpr) and lvalue.name == "__match_args__")
):
rvalue_type = remove_instance_last_known_values(rvalue_type)
Expand Down
3 changes: 3 additions & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,7 @@ def is_dynamic(self) -> bool:
"is_classvar",
"is_abstract_var",
"is_final",
"is_index_var",
"final_unset_in_class",
"final_set_in_init",
"explicit_self_type",
Expand Down Expand Up @@ -1005,6 +1006,7 @@ class Var(SymbolNode):
"is_classvar",
"is_abstract_var",
"is_final",
"is_index_var",
"final_unset_in_class",
"final_set_in_init",
"is_suppressed_import",
Expand Down Expand Up @@ -1039,6 +1041,7 @@ def __init__(self, name: str, type: mypy.types.Type | None = None) -> None:
self.is_settable_property = False
self.is_classvar = False
self.is_abstract_var = False
self.is_index_var = False
# Set to true when this variable refers to a module we were unable to
# parse for some reason (eg a silenced module)
self.is_suppressed_import = False
Expand Down
18 changes: 15 additions & 3 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4225,6 +4225,7 @@ def analyze_lvalue(
is_final: bool = False,
escape_comprehensions: bool = False,
has_explicit_value: bool = False,
is_index_var: bool = False,
) -> None:
"""Analyze an lvalue or assignment target.
Expand All @@ -4235,6 +4236,7 @@ def analyze_lvalue(
escape_comprehensions: If we are inside a comprehension, set the variable
in the enclosing scope instead. This implements
https://www.python.org/dev/peps/pep-0572/#scope-of-the-target
is_index_var: If lval is the index variable in a for loop
"""
if escape_comprehensions:
assert isinstance(lval, NameExpr), "assignment expression target must be NameExpr"
Expand All @@ -4245,6 +4247,7 @@ def analyze_lvalue(
is_final,
escape_comprehensions,
has_explicit_value=has_explicit_value,
is_index_var=is_index_var,
)
elif isinstance(lval, MemberExpr):
self.analyze_member_lvalue(lval, explicit_type, is_final, has_explicit_value)
Expand All @@ -4271,6 +4274,7 @@ def analyze_name_lvalue(
is_final: bool,
escape_comprehensions: bool,
has_explicit_value: bool,
is_index_var: bool,
) -> None:
"""Analyze an lvalue that targets a name expression.
Expand Down Expand Up @@ -4309,7 +4313,9 @@ def analyze_name_lvalue(

if (not existing or isinstance(existing.node, PlaceholderNode)) and not outer:
# Define new variable.
var = self.make_name_lvalue_var(lvalue, kind, not explicit_type, has_explicit_value)
var = self.make_name_lvalue_var(
lvalue, kind, not explicit_type, has_explicit_value, is_index_var
)
added = self.add_symbol(name, var, lvalue, escape_comprehensions=escape_comprehensions)
# Only bind expression if we successfully added name to symbol table.
if added:
Expand Down Expand Up @@ -4361,7 +4367,12 @@ def is_alias_for_final_name(self, name: str) -> bool:
return existing is not None and is_final_node(existing.node)

def make_name_lvalue_var(
self, lvalue: NameExpr, kind: int, inferred: bool, has_explicit_value: bool
self,
lvalue: NameExpr,
kind: int,
inferred: bool,
has_explicit_value: bool,
is_index_var: bool,
) -> Var:
"""Return a Var node for an lvalue that is a name expression."""
name = lvalue.name
Expand All @@ -4380,6 +4391,7 @@ def make_name_lvalue_var(
v._fullname = name
v.is_ready = False # Type not inferred yet
v.has_explicit_value = has_explicit_value
v.is_index_var = is_index_var
return v

def make_name_lvalue_point_to_existing_def(
Expand Down Expand Up @@ -5290,7 +5302,7 @@ def visit_for_stmt(self, s: ForStmt) -> None:
s.expr.accept(self)

# Bind index variables and check if they define new names.
self.analyze_lvalue(s.index, explicit_type=s.index_type is not None)
self.analyze_lvalue(s.index, explicit_type=s.index_type is not None, is_index_var=True)
if s.index_type:
if self.is_classvar(s.index_type):
self.fail_invalid_classvar(s.index)
Expand Down
29 changes: 29 additions & 0 deletions test-data/unit/check-inference.test
Original file line number Diff line number Diff line change
Expand Up @@ -1238,6 +1238,35 @@ class B: pass
[builtins fixtures/for.pyi]
[out]

[case testForStatementIndexNarrowing]
from typing_extensions import TypedDict

class X(TypedDict):
hourly: int
daily: int

x: X
for a in ("hourly", "daily"):
reveal_type(a) # N: Revealed type is "Union[Literal['hourly']?, Literal['daily']?]"
reveal_type(x[a]) # N: Revealed type is "builtins.int"
reveal_type(a.upper()) # N: Revealed type is "builtins.str"
c = a
reveal_type(c) # N: Revealed type is "builtins.str"
a = "monthly"
reveal_type(a) # N: Revealed type is "builtins.str"
a = "yearly"
reveal_type(a) # N: Revealed type is "builtins.str"
a = 1 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
reveal_type(a) # N: Revealed type is "builtins.str"
d = a
reveal_type(d) # N: Revealed type is "builtins.str"

b: str
for b in ("hourly", "daily"):
reveal_type(b) # N: Revealed type is "builtins.str"
reveal_type(b.upper()) # N: Revealed type is "builtins.str"
[builtins fixtures/for.pyi]


-- Regression tests
-- ----------------
Expand Down
4 changes: 3 additions & 1 deletion test-data/unit/fixtures/for.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ class type: pass
class tuple(Generic[t]):
def __iter__(self) -> Iterator[t]: pass
class function: pass
class ellipsis: pass
class bool: pass
class int: pass # for convenience
class str: pass # for convenience
class str: # for convenience
def upper(self) -> str: ...

class list(Iterable[t], Generic[t]):
def __iter__(self) -> Iterator[t]: pass
Expand Down

0 comments on commit 9e68959

Please sign in to comment.