Skip to content

Commit

Permalink
Switch to making final variables context-sensitive
Browse files Browse the repository at this point in the history
This commit modifies this PR to make selecting the type of final
variables context-sensitive. Now, when we do:

    x: Final = 1

...the variable `x` is normally inferred to be of type `int`. However,
if that variable is used in a context which expects `Literal`, we infer
the literal type.

This commit also removes some of the hacks to mypy and the tests that
the first iteration added.
  • Loading branch information
Michael0x2a committed Jan 4, 2019
1 parent e5a7495 commit 7335991
Show file tree
Hide file tree
Showing 21 changed files with 378 additions and 142 deletions.
7 changes: 5 additions & 2 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ class TypeChecker(NodeVisitor[None], CheckerPluginInterface):
# Type checking pass number (0 = first pass)
pass_num = 0
# Last pass number to take
last_pass = DEFAULT_LAST_PASS # type: int
last_pass = DEFAULT_LAST_PASS
# Have we deferred the current function? If yes, don't infer additional
# types during this pass within the function.
current_node_deferred = False
Expand Down Expand Up @@ -1809,7 +1809,10 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type
self.check_indexed_assignment(index_lvalue, rvalue, lvalue)

if inferred:
rvalue_type = self.expr_checker.accept(rvalue, infer_literal=inferred.is_final)
rvalue_type = self.expr_checker.accept(
rvalue,
in_final_declaration=inferred.is_final,
)
self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue)

def check_compatibility_all_supers(self, lvalue: RefExpr, lvalue_type: Optional[Type],
Expand Down
73 changes: 39 additions & 34 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from mypy.types import (
Type, AnyType, CallableType, Overloaded, NoneTyp, TypeVarDef,
TupleType, TypedDictType, Instance, TypeVarType, ErasedType, UnionType,
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType,
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType, LiteralValue,
true_only, false_only, is_named_instance, function_type, callable_type, FunctionLike,
StarType, is_optional, remove_optional, is_generic_instance
)
Expand Down Expand Up @@ -139,7 +139,7 @@ def __init__(self,
self.msg = msg
self.plugin = plugin
self.type_context = [None]
self.infer_literal = False
self.in_final_declaration = False
# Temporary overrides for expression types. This is currently
# used by the union math in overloads.
# TODO: refactor this to use a pattern similar to one in
Expand Down Expand Up @@ -211,10 +211,12 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:

def analyze_var_ref(self, var: Var, context: Context) -> Type:
if var.type:
if self.is_literal_context() and var.name() in {'True', 'False'}:
return LiteralType(var.name() == 'True', self.named_type('builtins.bool'))
else:
return var.type
if isinstance(var.type, Instance):
if self._is_literal_context() and var.type.final_value is not None:
return var.type.final_value
if var.name() in {'True', 'False'}:
return self._handle_literal_expr(var.name() == 'True', 'builtins.bool')
return var.type
else:
if not var.is_ready and self.chk.in_checked_function():
self.chk.handle_cannot_determine_type(var.name(), context)
Expand Down Expand Up @@ -693,7 +695,8 @@ def check_call(self,
elif isinstance(callee, Instance):
call_function = analyze_member_access('__call__', callee, context,
False, False, False, self.msg,
original_type=callee, chk=self.chk)
original_type=callee, chk=self.chk,
in_literal_context=self._is_literal_context())
return self.check_call(call_function, args, arg_kinds, context, arg_names,
callable_node, arg_messages)
elif isinstance(callee, TypeVarType):
Expand Down Expand Up @@ -1757,7 +1760,8 @@ def analyze_ordinary_member_access(self, e: MemberExpr,
original_type = self.accept(e.expr)
member_type = analyze_member_access(
e.name, original_type, e, is_lvalue, False, False,
self.msg, original_type=original_type, chk=self.chk)
self.msg, original_type=original_type, chk=self.chk,
in_literal_context=self._is_literal_context())
return member_type

def analyze_external_member_access(self, member: str, base_type: Type,
Expand All @@ -1767,35 +1771,36 @@ def analyze_external_member_access(self, member: str, base_type: Type,
"""
# TODO remove; no private definitions in mypy
return analyze_member_access(member, base_type, context, False, False, False,
self.msg, original_type=base_type, chk=self.chk)
self.msg, original_type=base_type, chk=self.chk,
in_literal_context=self._is_literal_context())

def _is_literal_context(self) -> bool:
return is_literal_type_like(self.type_context[-1])

def _handle_literal_expr(self, value: LiteralValue, fallback_name: str) -> Type:
typ = self.named_type(fallback_name)
if self._is_literal_context():
return LiteralType(value=value, fallback=typ)
elif self.in_final_declaration:
return typ.copy_with_final_value(value)
else:
return typ

def visit_int_expr(self, e: IntExpr) -> Type:
"""Type check an integer literal (trivial)."""
typ = self.named_type('builtins.int')
if self.is_literal_context():
return LiteralType(value=e.value, fallback=typ)
return typ
return self._handle_literal_expr(e.value, 'builtins.int')

def visit_str_expr(self, e: StrExpr) -> Type:
"""Type check a string literal (trivial)."""
typ = self.named_type('builtins.str')
if self.is_literal_context():
return LiteralType(value=e.value, fallback=typ)
return typ
return self._handle_literal_expr(e.value, 'builtins.str')

def visit_bytes_expr(self, e: BytesExpr) -> Type:
"""Type check a bytes literal (trivial)."""
typ = self.named_type('builtins.bytes')
if is_literal_type_like(self.type_context[-1]):
return LiteralType(value=e.value, fallback=typ)
return typ
return self._handle_literal_expr(e.value, 'builtins.bytes')

def visit_unicode_expr(self, e: UnicodeExpr) -> Type:
"""Type check a unicode literal (trivial)."""
typ = self.named_type('builtins.unicode')
if is_literal_type_like(self.type_context[-1]):
return LiteralType(value=e.value, fallback=typ)
return typ
return self._handle_literal_expr(e.value, 'builtins.unicode')

def visit_float_expr(self, e: FloatExpr) -> Type:
"""Type check a float literal (trivial)."""
Expand Down Expand Up @@ -1932,7 +1937,8 @@ def check_method_call_by_name(self,
"""
local_errors = local_errors or self.msg
method_type = analyze_member_access(method, base_type, context, False, False, True,
local_errors, original_type=base_type, chk=self.chk)
local_errors, original_type=base_type, chk=self.chk,
in_literal_context=self._is_literal_context())
return self.check_method_call(
method, base_type, method_type, args, arg_kinds, context, local_errors)

Expand Down Expand Up @@ -1996,6 +2002,7 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]:
context=context,
msg=local_errors,
chk=self.chk,
in_literal_context=self._is_literal_context()
)
if local_errors.is_errors():
return None
Expand Down Expand Up @@ -2946,7 +2953,8 @@ def analyze_super(self, e: SuperExpr, is_lvalue: bool) -> Type:
override_info=base,
context=e,
msg=self.msg,
chk=self.chk)
chk=self.chk,
in_literal_context=self._is_literal_context())
assert False, 'unreachable'
else:
# Invalid super. This has been reported by the semantic analyzer.
Expand Down Expand Up @@ -3113,16 +3121,16 @@ def accept(self,
type_context: Optional[Type] = None,
allow_none_return: bool = False,
always_allow_any: bool = False,
infer_literal: bool = False,
in_final_declaration: bool = False,
) -> Type:
"""Type check a node in the given type context. If allow_none_return
is True and this expression is a call, allow it to return None. This
applies only to this expression and not any subexpressions.
"""
if node in self.type_overrides:
return self.type_overrides[node]
old_infer_literal = self.infer_literal
self.infer_literal = infer_literal
old_in_final_declaration = self.in_final_declaration
self.in_final_declaration = in_final_declaration
self.type_context.append(type_context)
try:
if allow_none_return and isinstance(node, CallExpr):
Expand All @@ -3135,7 +3143,7 @@ def accept(self,
report_internal_error(err, self.chk.errors.file,
node.line, self.chk.errors, self.chk.options)
self.type_context.pop()
self.infer_literal = old_infer_literal
self.in_final_declaration = old_in_final_declaration
assert typ is not None
self.chk.store_type(node, typ)

Expand Down Expand Up @@ -3381,9 +3389,6 @@ def narrow_type_from_binder(self, expr: Expression, known_type: Type) -> Type:
return ans
return known_type

def is_literal_context(self) -> bool:
return self.infer_literal or is_literal_type_like(self.type_context[-1])


def has_any_type(t: Type) -> bool:
"""Whether t contains an Any type"""
Expand Down
9 changes: 7 additions & 2 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def analyze_member_access(name: str,
msg: MessageBuilder, *,
original_type: Type,
chk: 'mypy.checker.TypeChecker',
override_info: Optional[TypeInfo] = None) -> Type:
override_info: Optional[TypeInfo] = None,
in_literal_context: bool = False) -> Type:
"""Return the type of attribute 'name' of 'typ'.
The actual implementation is in '_analyze_member_access' and this docstring
Expand All @@ -96,7 +97,11 @@ def analyze_member_access(name: str,
context,
msg,
chk=chk)
return _analyze_member_access(name, typ, mx, override_info)
result = _analyze_member_access(name, typ, mx, override_info)
if in_literal_context and isinstance(result, Instance) and result.final_value is not None:
return result.final_value
else:
return result


def _analyze_member_access(name: str,
Expand Down
4 changes: 2 additions & 2 deletions mypy/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
PYTHON2_VERSION = (2, 7) # type: Final
PYTHON3_VERSION = (3, 6) # type: Final
PYTHON3_VERSION_MIN = (3, 4) # type: Final
CACHE_DIR = '.mypy_cache' # type: Final[str]
CONFIG_FILE = 'mypy.ini' # type: Final[str]
CACHE_DIR = '.mypy_cache' # type: Final
CONFIG_FILE = 'mypy.ini' # type: Final
SHARED_CONFIG_FILES = ('setup.cfg',) # type: Final
USER_CONFIG_FILES = ('~/.mypy.ini',) # type: Final
CONFIG_FILES = (CONFIG_FILE,) + SHARED_CONFIG_FILES + USER_CONFIG_FILES # type: Final
Expand Down
2 changes: 1 addition & 1 deletion mypy/erasetype.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def visit_deleted_type(self, t: DeletedType) -> Type:
return t

def visit_instance(self, t: Instance) -> Type:
return Instance(t.type, [AnyType(TypeOfAny.special_form)] * len(t.args), t.line)
return t.copy_modified(args=[AnyType(TypeOfAny.special_form)] * len(t.args))

def visit_type_var(self, t: TypeVarType) -> Type:
return AnyType(TypeOfAny.special_form)
Expand Down
6 changes: 2 additions & 4 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,14 @@ def visit_erased_type(self, t: ErasedType) -> Type:
raise RuntimeError()

def visit_instance(self, t: Instance) -> Type:
args = self.expand_types(t.args)
return Instance(t.type, args, t.line, t.column)
return t.copy_modified(args=self.expand_types(t.args))

def visit_type_var(self, t: TypeVarType) -> Type:
repl = self.variables.get(t.id, t)
if isinstance(repl, Instance):
inst = repl
# Return copy of instance with type erasure flag on.
return Instance(inst.type, inst.args, line=inst.line,
column=inst.column, erased=True)
return inst.copy_modified(erased=True)
else:
return repl

Expand Down
2 changes: 1 addition & 1 deletion mypy/reachability.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def infer_condition_value(expr: Expression, options: Options) -> int:
if alias.op == 'not':
expr = alias.expr
negated = True
result = TRUTH_VALUE_UNKNOWN # type: int
result = TRUTH_VALUE_UNKNOWN
if isinstance(expr, NameExpr):
name = expr.name
elif isinstance(expr, MemberExpr):
Expand Down
3 changes: 2 additions & 1 deletion mypy/sametypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def visit_deleted_type(self, left: DeletedType) -> bool:
def visit_instance(self, left: Instance) -> bool:
return (isinstance(self.right, Instance) and
left.type == self.right.type and
is_same_types(left.args, self.right.args))
is_same_types(left.args, self.right.args) and
left.final_value == self.right.final_value)

def visit_type_var(self, left: TypeVarType) -> bool:
return (isinstance(self.right, TypeVarType) and
Expand Down
21 changes: 14 additions & 7 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
from mypy.messages import CANNOT_ASSIGN_TO_TYPE, MessageBuilder
from mypy.types import (
FunctionLike, UnboundType, TypeVarDef, TupleType, UnionType, StarType, function_type,
CallableType, Overloaded, Instance, Type, AnyType, LiteralType,
CallableType, Overloaded, Instance, Type, AnyType,
TypeTranslator, TypeOfAny, TypeType, NoneTyp,
)
from mypy.nodes import implicit_module_attrs
Expand Down Expand Up @@ -1908,22 +1908,29 @@ def analyze_simple_literal_type(self, rvalue: Expression, is_final: bool) -> Opt
# inside type variables with value restrictions (like
# AnyStr).
return None
if isinstance(rvalue, FloatExpr):
return self.named_type_or_none('builtins.float')

if isinstance(rvalue, IntExpr):
typ = self.named_type_or_none('builtins.int')
if typ and is_final:
return LiteralType(rvalue.value, typ, rvalue.line, rvalue.column)
return typ.copy_with_final_value(rvalue.value)
return typ
if isinstance(rvalue, FloatExpr):
return self.named_type_or_none('builtins.float')
if isinstance(rvalue, StrExpr):
typ = self.named_type_or_none('builtins.str')
if typ and is_final:
return LiteralType(rvalue.value, typ, rvalue.line, rvalue.column)
return typ.copy_with_final_value(rvalue.value)
return typ
if isinstance(rvalue, BytesExpr):
return self.named_type_or_none('builtins.bytes')
typ = self.named_type_or_none('builtins.bytes')
if typ and is_final:
return typ.copy_with_final_value(rvalue.value)
return typ
if isinstance(rvalue, UnicodeExpr):
return self.named_type_or_none('builtins.unicode')
typ = self.named_type_or_none('builtins.unicode')
if typ and is_final:
return typ.copy_with_final_value(rvalue.value)
return typ

return None

Expand Down
3 changes: 2 additions & 1 deletion mypy/server/astdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,8 @@ def visit_deleted_type(self, typ: DeletedType) -> SnapshotItem:
def visit_instance(self, typ: Instance) -> SnapshotItem:
return ('Instance',
typ.type.fullname(),
snapshot_types(typ.args))
snapshot_types(typ.args),
None if typ.final_value is None else snapshot_type(typ.final_value))

def visit_type_var(self, typ: TypeVarType) -> SnapshotItem:
return ('TypeVar',
Expand Down
2 changes: 2 additions & 0 deletions mypy/server/astmerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,8 @@ def visit_instance(self, typ: Instance) -> None:
typ.type = self.fixup(typ.type)
for arg in typ.args:
arg.accept(self)
if typ.final_value:
typ.final_value.accept(self)

def visit_any(self, typ: AnyType) -> None:
pass
Expand Down
2 changes: 2 additions & 0 deletions mypy/server/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,8 @@ def visit_instance(self, typ: Instance) -> List[str]:
triggers = [trigger]
for arg in typ.args:
triggers.extend(self.get_type_triggers(arg))
if typ.final_value:
triggers.extend(self.get_type_triggers(typ.final_value))
return triggers

def visit_any(self, typ: AnyType) -> List[str]:
Expand Down
2 changes: 1 addition & 1 deletion mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def __init__(self, _all_: Optional[List[str]], pyversion: Tuple[int, int],
self._import_lines = [] # type: List[str]
self._indent = ''
self._vars = [[]] # type: List[List[str]]
self._state = EMPTY # type: str
self._state = EMPTY
self._toplevel_names = [] # type: List[str]
self._pyversion = pyversion
self._include_private = include_private
Expand Down
2 changes: 1 addition & 1 deletion mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ def get_member_flags(name: str, info: TypeInfo) -> Set[int]:
return {IS_CLASS_OR_STATIC}
# just a variable
if isinstance(v, Var) and not v.is_property:
flags = {IS_SETTABLE} # type: Set[int]
flags = {IS_SETTABLE}
if v.is_classvar:
flags.add(IS_CLASSVAR)
return flags
Expand Down
2 changes: 1 addition & 1 deletion mypy/type_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def visit_deleted_type(self, t: DeletedType) -> Type:
return t

def visit_instance(self, t: Instance) -> Type:
return Instance(t.type, self.translate_types(t.args), t.line, t.column)
return t.copy_modified(args=self.translate_types(t.args))

def visit_type_var(self, t: TypeVarType) -> Type:
return t
Expand Down
5 changes: 4 additions & 1 deletion mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,9 @@ def analyze_literal_param(self, idx: int, arg: Type, ctx: Context) -> Optional[L
elif isinstance(arg, (NoneTyp, LiteralType)):
# Types that we can just add directly to the literal/potential union of literals.
return [arg]
elif isinstance(arg, Instance) and arg.final_value is not None:
# Types generated from declarations like "var: Final = 4".
return [arg.final_value]
elif isinstance(arg, UnionType):
out = []
for union_arg in arg.items:
Expand Down Expand Up @@ -1073,7 +1076,7 @@ def replace_alias_tvars(tp: Type, vars: List[str], subs: List[Type],
def set_any_tvars(tp: Type, vars: List[str],
newline: int, newcolumn: int, implicit: bool = True) -> Type:
if implicit:
type_of_any = TypeOfAny.from_omitted_generics # type: int
type_of_any = TypeOfAny.from_omitted_generics
else:
type_of_any = TypeOfAny.special_form
any_type = AnyType(type_of_any, line=newline, column=newcolumn)
Expand Down
Loading

0 comments on commit 7335991

Please sign in to comment.