Skip to content

Commit

Permalink
Use context manager for Scope (#11053)
Browse files Browse the repository at this point in the history
Related issue: #1184
Follows up to #10569, #10685

This PR:
* Refactors `Scope`
  - Replaces `Scope.enter_file` with `Scope.module_scope`
  - Replaces `Scope.enter_function` with `Scope.function_scope`
  - Splits `Scope.leave()` into corresponding context manager
* Deletes unused files
  • Loading branch information
97littleleaf11 authored Sep 6, 2021
1 parent b17ae30 commit 725a24a
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 185 deletions.
84 changes: 40 additions & 44 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,29 +297,26 @@ def check_first_pass(self) -> None:
self.recurse_into_functions = True
with state.strict_optional_set(self.options.strict_optional):
self.errors.set_file(self.path, self.tree.fullname, scope=self.tscope)
self.tscope.enter_file(self.tree.fullname)
with self.enter_partial_types():
with self.binder.top_frame_context():
with self.tscope.module_scope(self.tree.fullname):
with self.enter_partial_types(), self.binder.top_frame_context():
for d in self.tree.defs:
self.accept(d)

assert not self.current_node_deferred
assert not self.current_node_deferred

all_ = self.globals.get('__all__')
if all_ is not None and all_.type is not None:
all_node = all_.node
assert all_node is not None
seq_str = self.named_generic_type('typing.Sequence',
[self.named_type('builtins.str')])
if self.options.python_version[0] < 3:
all_ = self.globals.get('__all__')
if all_ is not None and all_.type is not None:
all_node = all_.node
assert all_node is not None
seq_str = self.named_generic_type('typing.Sequence',
[self.named_type('builtins.unicode')])
if not is_subtype(all_.type, seq_str):
str_seq_s, all_s = format_type_distinctly(seq_str, all_.type)
self.fail(message_registry.ALL_MUST_BE_SEQ_STR.format(str_seq_s, all_s),
all_node)

self.tscope.leave()
[self.named_type('builtins.str')])
if self.options.python_version[0] < 3:
seq_str = self.named_generic_type('typing.Sequence',
[self.named_type('builtins.unicode')])
if not is_subtype(all_.type, seq_str):
str_seq_s, all_s = format_type_distinctly(seq_str, all_.type)
self.fail(message_registry.ALL_MUST_BE_SEQ_STR.format(str_seq_s, all_s),
all_node)

def check_second_pass(self,
todo: Optional[Sequence[Union[DeferredNode,
Expand All @@ -334,25 +331,24 @@ def check_second_pass(self,
if not todo and not self.deferred_nodes:
return False
self.errors.set_file(self.path, self.tree.fullname, scope=self.tscope)
self.tscope.enter_file(self.tree.fullname)
self.pass_num += 1
if not todo:
todo = self.deferred_nodes
else:
assert not self.deferred_nodes
self.deferred_nodes = []
done: Set[Union[DeferredNodeType, FineGrainedDeferredNodeType]] = set()
for node, active_typeinfo in todo:
if node in done:
continue
# This is useful for debugging:
# print("XXX in pass %d, class %s, function %s" %
# (self.pass_num, type_name, node.fullname or node.name))
done.add(node)
with self.tscope.class_scope(active_typeinfo) if active_typeinfo else nothing():
with self.scope.push_class(active_typeinfo) if active_typeinfo else nothing():
self.check_partial(node)
self.tscope.leave()
with self.tscope.module_scope(self.tree.fullname):
self.pass_num += 1
if not todo:
todo = self.deferred_nodes
else:
assert not self.deferred_nodes
self.deferred_nodes = []
done: Set[Union[DeferredNodeType, FineGrainedDeferredNodeType]] = set()
for node, active_typeinfo in todo:
if node in done:
continue
# This is useful for debugging:
# print("XXX in pass %d, class %s, function %s" %
# (self.pass_num, type_name, node.fullname or node.name))
done.add(node)
with self.tscope.class_scope(active_typeinfo) if active_typeinfo else nothing():
with self.scope.push_class(active_typeinfo) if active_typeinfo else nothing():
self.check_partial(node)
return True

def check_partial(self, node: Union[DeferredNodeType, FineGrainedDeferredNodeType]) -> None:
Expand Down Expand Up @@ -874,7 +870,7 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str])
if isinstance(typ.ret_type, TypeVarType):
if typ.ret_type.variance == CONTRAVARIANT:
self.fail(message_registry.RETURN_TYPE_CANNOT_BE_CONTRAVARIANT,
typ.ret_type)
typ.ret_type)

# Check that Generator functions have the appropriate return type.
if defn.is_generator:
Expand Down Expand Up @@ -992,7 +988,7 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str])
self.accept(item.body)
unreachable = self.binder.is_unreachable()

if (self.options.warn_no_return and not unreachable):
if self.options.warn_no_return and not unreachable:
if (defn.is_generator or
is_named_instance(self.return_types[-1], 'typing.AwaitableGenerator')):
return_type = self.get_generator_return_type(self.return_types[-1],
Expand Down Expand Up @@ -1083,7 +1079,7 @@ def is_unannotated_any(t: Type) -> bool:
code=codes.NO_UNTYPED_DEF)
elif fdef.is_generator:
if is_unannotated_any(self.get_generator_return_type(ret_type,
fdef.is_coroutine)):
fdef.is_coroutine)):
self.fail(message_registry.RETURN_TYPE_EXPECTED, fdef,
code=codes.NO_UNTYPED_DEF)
elif fdef.is_coroutine and isinstance(ret_type, Instance):
Expand Down Expand Up @@ -2641,8 +2637,7 @@ def check_rvalue_count_in_assignment(self, lvalues: List[Lvalue], rvalue_count:
len(lvalues) - 1, context)
return False
elif rvalue_count != len(lvalues):
self.msg.wrong_number_values_to_unpack(rvalue_count,
len(lvalues), context)
self.msg.wrong_number_values_to_unpack(rvalue_count, len(lvalues), context)
return False
return True

Expand Down Expand Up @@ -2896,8 +2891,7 @@ def check_lvalue(self, lvalue: Lvalue) -> Tuple[Optional[Type],
elif isinstance(lvalue, IndexExpr):
index_lvalue = lvalue
elif isinstance(lvalue, MemberExpr):
lvalue_type = self.expr_checker.analyze_ordinary_member_access(lvalue,
True)
lvalue_type = self.expr_checker.analyze_ordinary_member_access(lvalue, True)
self.store_type(lvalue, lvalue_type)
elif isinstance(lvalue, NameExpr):
lvalue_type = self.expr_checker.analyze_ref_expr(lvalue, lvalue=True)
Expand Down Expand Up @@ -4144,6 +4138,7 @@ def is_type_call(expr: CallExpr) -> bool:
"""Is expr a call to type with one argument?"""
return (refers_to_fullname(expr.callee, 'builtins.type')
and len(expr.args) == 1)

# exprs that are being passed into type
exprs_in_type_calls: List[Expression] = []
# type that is being compared to type(expr)
Expand Down Expand Up @@ -4194,6 +4189,7 @@ def combine_maps(list_maps: List[TypeMap]) -> TypeMap:
if d is not None:
result_map.update(d)
return result_map

if_map = combine_maps(if_maps)
# type(x) == T is only true when x has the same type as T, meaning
# that it can be false if x is an instance of a subclass of T. That means
Expand Down
10 changes: 0 additions & 10 deletions mypy/nullcontext.py

This file was deleted.

Empty file removed mypy/ordered_dict.py
Empty file.
66 changes: 30 additions & 36 deletions mypy/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from contextlib import contextmanager
from typing import List, Optional, Iterator, Tuple

from mypy.backports import nullcontext
from mypy.nodes import TypeInfo, FuncBase


Expand Down Expand Up @@ -51,18 +52,30 @@ def current_function_name(self) -> Optional[str]:
"""Return the current function's short name if it exists"""
return self.function.name if self.function else None

def enter_file(self, prefix: str) -> None:
@contextmanager
def module_scope(self, prefix: str) -> Iterator[None]:
self.module = prefix
self.classes = []
self.function = None
self.ignored = 0
yield
assert self.module
self.module = None

def enter_function(self, fdef: FuncBase) -> None:
@contextmanager
def function_scope(self, fdef: FuncBase) -> Iterator[None]:
if not self.function:
self.function = fdef
else:
# Nested functions are part of the topmost function target.
self.ignored += 1
yield
if self.ignored:
# Leave a scope that's included in the enclosing target.
self.ignored -= 1
else:
assert self.function
self.function = None

def enter_class(self, info: TypeInfo) -> None:
"""Enter a class target scope."""
Expand All @@ -72,53 +85,34 @@ def enter_class(self, info: TypeInfo) -> None:
# Classes within functions are part of the enclosing function target.
self.ignored += 1

def leave(self) -> None:
"""Leave the innermost scope (can be any kind of scope)."""
def leave_class(self) -> None:
"""Leave a class target scope."""
if self.ignored:
# Leave a scope that's included in the enclosing target.
self.ignored -= 1
elif self.function:
# Function is always the innermost target.
self.function = None
elif self.classes:
else:
assert self.classes
# Leave the innermost class.
self.classes.pop()
else:
# Leave module.
assert self.module
self.module = None

@contextmanager
def class_scope(self, info: TypeInfo) -> Iterator[None]:
self.enter_class(info)
yield
self.leave_class()

def save(self) -> SavedScope:
"""Produce a saved scope that can be entered with saved_scope()"""
assert self.module
# We only save the innermost class, which is sufficient since
# the rest are only needed for when classes are left.
cls = self.classes[-1] if self.classes else None
return (self.module, cls, self.function)

@contextmanager
def function_scope(self, fdef: FuncBase) -> Iterator[None]:
self.enter_function(fdef)
yield
self.leave()

@contextmanager
def class_scope(self, info: TypeInfo) -> Iterator[None]:
self.enter_class(info)
yield
self.leave()
return self.module, cls, self.function

@contextmanager
def saved_scope(self, saved: SavedScope) -> Iterator[None]:
module, info, function = saved
self.enter_file(module)
if info:
self.enter_class(info)
if function:
self.enter_function(function)
yield
if function:
self.leave()
if info:
self.leave()
self.leave()
with self.module_scope(module):
with self.class_scope(info) if info else nullcontext():
with self.function_scope(function) if function else nullcontext():
yield
59 changes: 29 additions & 30 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,36 +529,35 @@ def file_context(self,
self.errors.set_file(file_node.path, file_node.fullname, scope=scope)
self.cur_mod_node = file_node
self.cur_mod_id = file_node.fullname
scope.enter_file(self.cur_mod_id)
self._is_stub_file = file_node.path.lower().endswith('.pyi')
self._is_typeshed_stub_file = is_typeshed_file(file_node.path)
self.globals = file_node.names
self.tvar_scope = TypeVarLikeScope()

self.named_tuple_analyzer = NamedTupleAnalyzer(options, self)
self.typed_dict_analyzer = TypedDictAnalyzer(options, self, self.msg)
self.enum_call_analyzer = EnumCallAnalyzer(options, self)
self.newtype_analyzer = NewTypeAnalyzer(options, self, self.msg)

# Counter that keeps track of references to undefined things potentially caused by
# incomplete namespaces.
self.num_incomplete_refs = 0

if active_type:
self.incomplete_type_stack.append(False)
scope.enter_class(active_type)
self.enter_class(active_type.defn.info)
for tvar in active_type.defn.type_vars:
self.tvar_scope.bind_existing(tvar)

yield

if active_type:
scope.leave()
self.leave_class()
self.type = None
self.incomplete_type_stack.pop()
scope.leave()
with scope.module_scope(self.cur_mod_id):
self._is_stub_file = file_node.path.lower().endswith('.pyi')
self._is_typeshed_stub_file = is_typeshed_file(file_node.path)
self.globals = file_node.names
self.tvar_scope = TypeVarLikeScope()

self.named_tuple_analyzer = NamedTupleAnalyzer(options, self)
self.typed_dict_analyzer = TypedDictAnalyzer(options, self, self.msg)
self.enum_call_analyzer = EnumCallAnalyzer(options, self)
self.newtype_analyzer = NewTypeAnalyzer(options, self, self.msg)

# Counter that keeps track of references to undefined things potentially caused by
# incomplete namespaces.
self.num_incomplete_refs = 0

if active_type:
self.incomplete_type_stack.append(False)
scope.enter_class(active_type)
self.enter_class(active_type.defn.info)
for tvar in active_type.defn.type_vars:
self.tvar_scope.bind_existing(tvar)

yield

if active_type:
scope.leave_class()
self.leave_class()
self.type = None
self.incomplete_type_stack.pop()
del self.options

#
Expand Down
5 changes: 2 additions & 3 deletions mypy/semanal_typeargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,8 @@ def __init__(self, errors: Errors, options: Options, is_typeshed_file: bool) ->

def visit_mypy_file(self, o: MypyFile) -> None:
self.errors.set_file(o.path, o.fullname, scope=self.scope)
self.scope.enter_file(o.fullname)
super().visit_mypy_file(o)
self.scope.leave()
with self.scope.module_scope(o.fullname):
super().visit_mypy_file(o)

def visit_func(self, defn: FuncItem) -> None:
if not self.recurse_into_functions:
Expand Down
Loading

0 comments on commit 725a24a

Please sign in to comment.