Skip to content

Commit

Permalink
Add interactions between Literal and Final (#6081)
Browse files Browse the repository at this point in the history
This pull request adds logic to handle interactions between Literal and
Final. In short, if the user were to define a variable like `x: Final = 3` and
latter do `some_func(x)`, mypy will attempt to type-check the code almost
as if the user had done `some_func(3)` instead.

This normally does not make a difference, except when type-checking code
using literal types. For example, if `some_func` accepts a `Literal[3]` up
above, the code would type-check since `x` cannot be anything other
then a `3`.

Or to put it another way, this pull request makes variables that use `Final`
with the type omitted context-sensitive.
  • Loading branch information
Michael0x2a authored Jan 8, 2019
1 parent fd048ab commit 94fe11c
Show file tree
Hide file tree
Showing 15 changed files with 661 additions and 69 deletions.
7 changes: 5 additions & 2 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1810,8 +1810,11 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type
self.check_indexed_assignment(index_lvalue, rvalue, lvalue)

if inferred:
self.infer_variable_type(inferred, lvalue, self.expr_checker.accept(rvalue),
rvalue)
rvalue_type = self.expr_checker.accept(
rvalue,
in_final_declaration=inferred.is_final,
)
self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue)

def check_compatibility_all_supers(self, lvalue: RefExpr, lvalue_type: Optional[Type],
rvalue: Expression) -> bool:
Expand Down
95 changes: 69 additions & 26 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from mypy.types import (
Type, AnyType, CallableType, Overloaded, NoneTyp, TypeVarDef,
TupleType, TypedDictType, Instance, TypeVarType, ErasedType, UnionType,
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType,
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType, LiteralValue,
true_only, false_only, is_named_instance, function_type, callable_type, FunctionLike,
StarType, is_optional, remove_optional, is_generic_instance
)
Expand Down Expand Up @@ -139,6 +139,16 @@ def __init__(self,
self.msg = msg
self.plugin = plugin
self.type_context = [None]

# Set to 'True' whenever we are checking the expression in some 'Final' declaration.
# For example, if we're checking the "3" in a statement like "var: Final = 3".
#
# This flag changes the type that eventually gets inferred for "var". Instead of
# inferring *just* a 'builtins.int' instance, we infer an instance that keeps track
# of the underlying literal value. See the comments in Instance's constructors for
# more details.
self.in_final_declaration = False

# Temporary overrides for expression types. This is currently
# used by the union math in overloads.
# TODO: refactor this to use a pattern similar to one in
Expand Down Expand Up @@ -210,10 +220,12 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:

def analyze_var_ref(self, var: Var, context: Context) -> Type:
if var.type:
if is_literal_type_like(self.type_context[-1]) and var.name() in {'True', 'False'}:
return LiteralType(var.name() == 'True', self.named_type('builtins.bool'))
else:
return var.type
if isinstance(var.type, Instance):
if self.is_literal_context() and var.type.final_value is not None:
return var.type.final_value
if var.name() in {'True', 'False'}:
return self.infer_literal_expr_type(var.name() == 'True', 'builtins.bool')
return var.type
else:
if not var.is_ready and self.chk.in_checked_function():
self.chk.handle_cannot_determine_type(var.name(), context)
Expand Down Expand Up @@ -691,7 +703,8 @@ def check_call(self,
elif isinstance(callee, Instance):
call_function = analyze_member_access('__call__', callee, context,
False, False, False, self.msg,
original_type=callee, chk=self.chk)
original_type=callee, chk=self.chk,
in_literal_context=self.is_literal_context())
return self.check_call(call_function, args, arg_kinds, context, arg_names,
callable_node, arg_messages)
elif isinstance(callee, TypeVarType):
Expand Down Expand Up @@ -1755,7 +1768,8 @@ def analyze_ordinary_member_access(self, e: MemberExpr,
original_type = self.accept(e.expr)
member_type = analyze_member_access(
e.name, original_type, e, is_lvalue, False, False,
self.msg, original_type=original_type, chk=self.chk)
self.msg, original_type=original_type, chk=self.chk,
in_literal_context=self.is_literal_context())
return member_type

def analyze_external_member_access(self, member: str, base_type: Type,
Expand All @@ -1765,35 +1779,57 @@ def analyze_external_member_access(self, member: str, base_type: Type,
"""
# TODO remove; no private definitions in mypy
return analyze_member_access(member, base_type, context, False, False, False,
self.msg, original_type=base_type, chk=self.chk)
self.msg, original_type=base_type, chk=self.chk,
in_literal_context=self.is_literal_context())

def is_literal_context(self) -> bool:
return is_literal_type_like(self.type_context[-1])

def infer_literal_expr_type(self, value: LiteralValue, fallback_name: str) -> Type:
"""Analyzes the given literal expression and determines if we should be
inferring an Instance type, a Literal[...] type, or an Instance that
remembers the original literal. We...
1. ...Infer a normal Instance in most circumstances.
2. ...Infer a Literal[...] if we're in a literal context. For example, if we
were analyzing the "3" in "foo(3)" where "foo" has a signature of
"def foo(Literal[3]) -> None", we'd want to infer that the "3" has a
type of Literal[3] instead of Instance.
3. ...Infer an Instance that remembers the original Literal if we're declaring
a Final variable with an inferred type -- for example, "bar" in "bar: Final = 3"
would be assigned an Instance that remembers it originated from a '3'. See
the comments in Instance's constructor for more details.
"""
typ = self.named_type(fallback_name)
if self.is_literal_context():
return LiteralType(value=value, fallback=typ)
elif self.in_final_declaration:
return typ.copy_modified(final_value=LiteralType(
value=value,
fallback=typ,
line=typ.line,
column=typ.column,
))
else:
return typ

def visit_int_expr(self, e: IntExpr) -> Type:
"""Type check an integer literal (trivial)."""
typ = self.named_type('builtins.int')
if is_literal_type_like(self.type_context[-1]):
return LiteralType(value=e.value, fallback=typ)
return typ
return self.infer_literal_expr_type(e.value, 'builtins.int')

def visit_str_expr(self, e: StrExpr) -> Type:
"""Type check a string literal (trivial)."""
typ = self.named_type('builtins.str')
if is_literal_type_like(self.type_context[-1]):
return LiteralType(value=e.value, fallback=typ)
return typ
return self.infer_literal_expr_type(e.value, 'builtins.str')

def visit_bytes_expr(self, e: BytesExpr) -> Type:
"""Type check a bytes literal (trivial)."""
typ = self.named_type('builtins.bytes')
if is_literal_type_like(self.type_context[-1]):
return LiteralType(value=e.value, fallback=typ)
return typ
return self.infer_literal_expr_type(e.value, 'builtins.bytes')

def visit_unicode_expr(self, e: UnicodeExpr) -> Type:
"""Type check a unicode literal (trivial)."""
typ = self.named_type('builtins.unicode')
if is_literal_type_like(self.type_context[-1]):
return LiteralType(value=e.value, fallback=typ)
return typ
return self.infer_literal_expr_type(e.value, 'builtins.unicode')

def visit_float_expr(self, e: FloatExpr) -> Type:
"""Type check a float literal (trivial)."""
Expand Down Expand Up @@ -1930,7 +1966,8 @@ def check_method_call_by_name(self,
"""
local_errors = local_errors or self.msg
method_type = analyze_member_access(method, base_type, context, False, False, True,
local_errors, original_type=base_type, chk=self.chk)
local_errors, original_type=base_type, chk=self.chk,
in_literal_context=self.is_literal_context())
return self.check_method_call(
method, base_type, method_type, args, arg_kinds, context, local_errors)

Expand Down Expand Up @@ -1994,6 +2031,7 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]:
context=context,
msg=local_errors,
chk=self.chk,
in_literal_context=self.is_literal_context()
)
if local_errors.is_errors():
return None
Expand Down Expand Up @@ -2950,7 +2988,8 @@ def analyze_super(self, e: SuperExpr, is_lvalue: bool) -> Type:
override_info=base,
context=e,
msg=self.msg,
chk=self.chk)
chk=self.chk,
in_literal_context=self.is_literal_context())
assert False, 'unreachable'
else:
# Invalid super. This has been reported by the semantic analyzer.
Expand Down Expand Up @@ -3117,13 +3156,16 @@ def accept(self,
type_context: Optional[Type] = None,
allow_none_return: bool = False,
always_allow_any: bool = False,
in_final_declaration: bool = False,
) -> Type:
"""Type check a node in the given type context. If allow_none_return
is True and this expression is a call, allow it to return None. This
applies only to this expression and not any subexpressions.
"""
if node in self.type_overrides:
return self.type_overrides[node]
old_in_final_declaration = self.in_final_declaration
self.in_final_declaration = in_final_declaration
self.type_context.append(type_context)
try:
if allow_none_return and isinstance(node, CallExpr):
Expand All @@ -3136,6 +3178,7 @@ def accept(self,
report_internal_error(err, self.chk.errors.file,
node.line, self.chk.errors, self.chk.options)
self.type_context.pop()
self.in_final_declaration = old_in_final_declaration
assert typ is not None
self.chk.store_type(node, typ)

Expand Down
9 changes: 7 additions & 2 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def analyze_member_access(name: str,
msg: MessageBuilder, *,
original_type: Type,
chk: 'mypy.checker.TypeChecker',
override_info: Optional[TypeInfo] = None) -> Type:
override_info: Optional[TypeInfo] = None,
in_literal_context: bool = False) -> Type:
"""Return the type of attribute 'name' of 'typ'.
The actual implementation is in '_analyze_member_access' and this docstring
Expand All @@ -96,7 +97,11 @@ def analyze_member_access(name: str,
context,
msg,
chk=chk)
return _analyze_member_access(name, typ, mx, override_info)
result = _analyze_member_access(name, typ, mx, override_info)
if in_literal_context and isinstance(result, Instance) and result.final_value is not None:
return result.final_value
else:
return result


def _analyze_member_access(name: str,
Expand Down
2 changes: 2 additions & 0 deletions mypy/fixup.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ def visit_instance(self, inst: Instance) -> None:
base.accept(self)
for a in inst.args:
a.accept(self)
if inst.final_value is not None:
inst.final_value.accept(self)

def visit_any(self, o: Any) -> None:
pass # Nothing to descend into.
Expand Down
3 changes: 2 additions & 1 deletion mypy/sametypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def visit_deleted_type(self, left: DeletedType) -> bool:
def visit_instance(self, left: Instance) -> bool:
return (isinstance(self.right, Instance) and
left.type == self.right.type and
is_same_types(left.args, self.right.args))
is_same_types(left.args, self.right.args) and
left.final_value == self.right.final_value)

def visit_type_var(self, left: TypeVarType) -> bool:
return (isinstance(self.right, TypeVarType) and
Expand Down
37 changes: 27 additions & 10 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
from mypy.messages import CANNOT_ASSIGN_TO_TYPE, MessageBuilder
from mypy.types import (
FunctionLike, UnboundType, TypeVarDef, TupleType, UnionType, StarType, function_type,
CallableType, Overloaded, Instance, Type, AnyType,
CallableType, Overloaded, Instance, Type, AnyType, LiteralType, LiteralValue,
TypeTranslator, TypeOfAny, TypeType, NoneTyp,
)
from mypy.nodes import implicit_module_attrs
Expand Down Expand Up @@ -1760,9 +1760,9 @@ def final_cb(keep_final: bool) -> None:
self.type and self.type.is_protocol and not self.is_func_scope()):
self.fail('All protocol members must have explicitly declared types', s)
# Set the type if the rvalue is a simple literal (even if the above error occurred).
if len(s.lvalues) == 1 and isinstance(s.lvalues[0], NameExpr):
if len(s.lvalues) == 1 and isinstance(s.lvalues[0], RefExpr):
if s.lvalues[0].is_inferred_def:
s.type = self.analyze_simple_literal_type(s.rvalue)
s.type = self.analyze_simple_literal_type(s.rvalue, s.is_final_def)
if s.type:
# Store type into nodes.
for lvalue in s.lvalues:
Expand Down Expand Up @@ -1900,8 +1900,10 @@ def unbox_literal(self, e: Expression) -> Optional[Union[int, float, bool, str]]
return True if e.name == 'True' else False
return None

def analyze_simple_literal_type(self, rvalue: Expression) -> Optional[Type]:
"""Return builtins.int if rvalue is an int literal, etc."""
def analyze_simple_literal_type(self, rvalue: Expression, is_final: bool) -> Optional[Type]:
"""Return builtins.int if rvalue is an int literal, etc.
If this is a 'Final' context, we return "Literal[...]" instead."""
if self.options.semantic_analysis_only or self.function_stack:
# Skip this if we're only doing the semantic analysis pass.
# This is mostly to avoid breaking unit tests.
Expand All @@ -1910,16 +1912,31 @@ def analyze_simple_literal_type(self, rvalue: Expression) -> Optional[Type]:
# inside type variables with value restrictions (like
# AnyStr).
return None
if isinstance(rvalue, IntExpr):
return self.named_type_or_none('builtins.int')
if isinstance(rvalue, FloatExpr):
return self.named_type_or_none('builtins.float')

value = None # type: LiteralValue
type_name = None # type: Optional[str]
if isinstance(rvalue, IntExpr):
value, type_name = rvalue.value, 'builtins.int'
if isinstance(rvalue, StrExpr):
return self.named_type_or_none('builtins.str')
value, type_name = rvalue.value, 'builtins.str'
if isinstance(rvalue, BytesExpr):
return self.named_type_or_none('builtins.bytes')
value, type_name = rvalue.value, 'builtins.bytes'
if isinstance(rvalue, UnicodeExpr):
return self.named_type_or_none('builtins.unicode')
value, type_name = rvalue.value, 'builtins.unicode'

if type_name is not None:
typ = self.named_type_or_none(type_name)
if typ and is_final:
return typ.copy_modified(final_value=LiteralType(
value=value,
fallback=typ,
line=typ.line,
column=typ.column,
))
return typ

return None

def analyze_alias(self, rvalue: Expression) -> Tuple[Optional[Type], List[str],
Expand Down
3 changes: 2 additions & 1 deletion mypy/server/astdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,8 @@ def visit_deleted_type(self, typ: DeletedType) -> SnapshotItem:
def visit_instance(self, typ: Instance) -> SnapshotItem:
return ('Instance',
typ.type.fullname(),
snapshot_types(typ.args))
snapshot_types(typ.args),
None if typ.final_value is None else snapshot_type(typ.final_value))

def visit_type_var(self, typ: TypeVarType) -> SnapshotItem:
return ('TypeVar',
Expand Down
2 changes: 2 additions & 0 deletions mypy/server/astmerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,8 @@ def visit_instance(self, typ: Instance) -> None:
typ.type = self.fixup(typ.type)
for arg in typ.args:
arg.accept(self)
if typ.final_value:
typ.final_value.accept(self)

def visit_any(self, typ: AnyType) -> None:
pass
Expand Down
2 changes: 2 additions & 0 deletions mypy/server/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,8 @@ def visit_instance(self, typ: Instance) -> List[str]:
triggers = [trigger]
for arg in typ.args:
triggers.extend(self.get_type_triggers(arg))
if typ.final_value:
triggers.extend(self.get_type_triggers(typ.final_value))
return triggers

def visit_any(self, typ: AnyType) -> List[str]:
Expand Down
15 changes: 13 additions & 2 deletions mypy/type_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from abc import abstractmethod
from collections import OrderedDict
from typing import Generic, TypeVar, cast, Any, List, Callable, Iterable
from typing import Generic, TypeVar, cast, Any, List, Callable, Iterable, Optional
from mypy_extensions import trait

T = TypeVar('T')
Expand Down Expand Up @@ -159,7 +159,18 @@ def visit_deleted_type(self, t: DeletedType) -> Type:
return t

def visit_instance(self, t: Instance) -> Type:
return Instance(t.type, self.translate_types(t.args), t.line, t.column)
final_value = None # type: Optional[LiteralType]
if t.final_value is not None:
raw_final_value = t.final_value.accept(self)
assert isinstance(raw_final_value, LiteralType)
final_value = raw_final_value
return Instance(
typ=t.type,
args=self.translate_types(t.args),
line=t.line,
column=t.column,
final_value=final_value,
)

def visit_type_var(self, t: TypeVarType) -> Type:
return t
Expand Down
3 changes: 3 additions & 0 deletions mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,9 @@ def analyze_literal_param(self, idx: int, arg: Type, ctx: Context) -> Optional[L
elif isinstance(arg, (NoneTyp, LiteralType)):
# Types that we can just add directly to the literal/potential union of literals.
return [arg]
elif isinstance(arg, Instance) and arg.final_value is not None:
# Types generated from declarations like "var: Final = 4".
return [arg.final_value]
elif isinstance(arg, UnionType):
out = []
for union_arg in arg.items:
Expand Down
Loading

0 comments on commit 94fe11c

Please sign in to comment.