diff --git a/mypy/partially_defined.py b/mypy/partially_defined.py index 0005282d92a9..3c9dec13af70 100644 --- a/mypy/partially_defined.py +++ b/mypy/partially_defined.py @@ -109,11 +109,9 @@ def is_undefined(self, name: str) -> bool: branch = self.branches[-1] return name not in branch.may_be_defined and name not in branch.must_be_defined - def is_defined_in_different_branch(self, name: str) -> bool: + def is_defined_in_a_branch(self, name: str) -> bool: assert len(self.branches) > 0 - if not self.is_undefined(name): - return False - for b in self.branches[: len(self.branches) - 1]: + for b in self.branches: if name in b.must_be_defined or name in b.may_be_defined: return True return False @@ -213,7 +211,13 @@ def is_partially_defined(self, name: str) -> bool: def is_defined_in_different_branch(self, name: str) -> bool: """This will return true if a variable is defined in a branch that's not the current branch.""" assert len(self._scope().branch_stmts) > 0 - return self._scope().branch_stmts[-1].is_defined_in_different_branch(name) + stmt = self._scope().branch_stmts[-1] + if not stmt.is_undefined(name): + return False + for stmt in self._scope().branch_stmts: + if stmt.is_defined_in_a_branch(name): + return True + return False def is_undefined(self, name: str) -> bool: assert len(self._scope().branch_stmts) > 0 diff --git a/test-data/unit/check-partially-defined.test b/test-data/unit/check-partially-defined.test index e91e7aa65e7b..11aa30642314 100644 --- a/test-data/unit/check-partially-defined.test +++ b/test-data/unit/check-partially-defined.test @@ -285,6 +285,19 @@ def f1() -> None: else: y = x # No error. +def f2() -> None: + if int(): + x = 0 + elif int(): + y = x # E: Name "x" is used before definition + else: + y = x # E: Name "x" is used before definition + if int(): + z = x # E: Name "x" is used before definition + x = 1 + else: + x = 2 + w = x # No error. [case testDefinedDifferentBranchPartiallyDefined] # flags: --enable-error-code partially-defined --enable-error-code use-before-def @@ -295,11 +308,17 @@ def f0() -> None: if first_iter: first_iter = False x = 0 - else: + elif int(): # This is technically a false positive but mypy isn't smart enough for this yet. y = x # E: Name "x" may be undefined - z = x # E: Name "x" may be undefined - + else: + y = x # E: Name "x" may be undefined + if int(): + z = x # E: Name "x" may be undefined + x = 1 + else: + x = 2 + w = x # No error. def f1() -> None: while True: