Skip to content

Commit

Permalink
stubgen: Support yield from statements (#15271)
Browse files Browse the repository at this point in the history
Resolves #10744
  • Loading branch information
hamdanal authored May 21, 2023
1 parent c2d02a3 commit 2ede35f
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 12 deletions.
27 changes: 18 additions & 9 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,12 @@
report_missing,
walk_packages,
)
from mypy.traverser import all_yield_expressions, has_return_statement, has_yield_expression
from mypy.traverser import (
all_yield_expressions,
has_return_statement,
has_yield_expression,
has_yield_from_expression,
)
from mypy.types import (
OVERLOAD_NAMES,
TPDICT_NAMES,
Expand Down Expand Up @@ -774,18 +779,22 @@ def visit_func_def(self, o: FuncDef) -> None:
retname = None # implicit Any
elif o.name in KNOWN_MAGIC_METHODS_RETURN_TYPES:
retname = KNOWN_MAGIC_METHODS_RETURN_TYPES[o.name]
elif has_yield_expression(o):
elif has_yield_expression(o) or has_yield_from_expression(o):
self.add_typing_import("Generator")
yield_name = "None"
send_name = "None"
return_name = "None"
for expr, in_assignment in all_yield_expressions(o):
if expr.expr is not None and not self.is_none_expr(expr.expr):
self.add_typing_import("Incomplete")
yield_name = self.typing_name("Incomplete")
if in_assignment:
self.add_typing_import("Incomplete")
send_name = self.typing_name("Incomplete")
if has_yield_from_expression(o):
self.add_typing_import("Incomplete")
yield_name = send_name = self.typing_name("Incomplete")
else:
for expr, in_assignment in all_yield_expressions(o):
if expr.expr is not None and not self.is_none_expr(expr.expr):
self.add_typing_import("Incomplete")
yield_name = self.typing_name("Incomplete")
if in_assignment:
self.add_typing_import("Incomplete")
send_name = self.typing_name("Incomplete")
if has_return_statement(o):
self.add_typing_import("Incomplete")
return_name = self.typing_name("Incomplete")
Expand Down
36 changes: 36 additions & 0 deletions mypy/traverser.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,21 @@ def has_yield_expression(fdef: FuncBase) -> bool:
return seeker.found


class YieldFromSeeker(FuncCollectorBase):
def __init__(self) -> None:
super().__init__()
self.found = False

def visit_yield_from_expr(self, o: YieldFromExpr) -> None:
self.found = True


def has_yield_from_expression(fdef: FuncBase) -> bool:
seeker = YieldFromSeeker()
fdef.accept(seeker)
return seeker.found


class AwaitSeeker(TraverserVisitor):
def __init__(self) -> None:
super().__init__()
Expand Down Expand Up @@ -922,3 +937,24 @@ def all_yield_expressions(node: Node) -> list[tuple[YieldExpr, bool]]:
v = YieldCollector()
node.accept(v)
return v.yield_expressions


class YieldFromCollector(FuncCollectorBase):
def __init__(self) -> None:
super().__init__()
self.in_assignment = False
self.yield_from_expressions: list[tuple[YieldFromExpr, bool]] = []

def visit_assignment_stmt(self, stmt: AssignmentStmt) -> None:
self.in_assignment = True
super().visit_assignment_stmt(stmt)
self.in_assignment = False

def visit_yield_from_expr(self, expr: YieldFromExpr) -> None:
self.yield_from_expressions.append((expr, self.in_assignment))


def all_yield_from_expressions(node: Node) -> list[tuple[YieldFromExpr, bool]]:
v = YieldFromCollector()
node.accept(v)
return v.yield_from_expressions
82 changes: 79 additions & 3 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -1231,6 +1231,9 @@ def h1():
def h2():
yield
return "abc"
def h3():
yield
return None
def all():
x = yield 123
return "abc"
Expand All @@ -1242,6 +1245,7 @@ def f() -> Generator[Incomplete, None, None]: ...
def g() -> Generator[None, Incomplete, None]: ...
def h1() -> Generator[None, None, None]: ...
def h2() -> Generator[None, None, Incomplete]: ...
def h3() -> Generator[None, None, None]: ...
def all() -> Generator[Incomplete, Incomplete, Incomplete]: ...

[case testFunctionYieldsNone]
Expand Down Expand Up @@ -1270,6 +1274,69 @@ class Generator: ...

def f() -> _Generator[Incomplete, None, None]: ...

[case testGeneratorYieldFrom]
def g1():
yield from x
def g2():
y = yield from x
def g3():
yield from x
return
def g4():
yield from x
return None
def g5():
yield from x
return z

[out]
from _typeshed import Incomplete
from collections.abc import Generator

def g1() -> Generator[Incomplete, Incomplete, None]: ...
def g2() -> Generator[Incomplete, Incomplete, None]: ...
def g3() -> Generator[Incomplete, Incomplete, None]: ...
def g4() -> Generator[Incomplete, Incomplete, None]: ...
def g5() -> Generator[Incomplete, Incomplete, Incomplete]: ...

[case testGeneratorYieldAndYieldFrom]
def g1():
yield x1
yield from x2
def g2():
yield x1
y = yield from x2
def g3():
y = yield x1
yield from x2
def g4():
yield x1
yield from x2
return
def g5():
yield x1
yield from x2
return None
def g6():
yield x1
yield from x2
return z
def g7():
yield None
yield from x2

[out]
from _typeshed import Incomplete
from collections.abc import Generator

def g1() -> Generator[Incomplete, Incomplete, None]: ...
def g2() -> Generator[Incomplete, Incomplete, None]: ...
def g3() -> Generator[Incomplete, Incomplete, None]: ...
def g4() -> Generator[Incomplete, Incomplete, None]: ...
def g5() -> Generator[Incomplete, Incomplete, None]: ...
def g6() -> Generator[Incomplete, Incomplete, Incomplete]: ...
def g7() -> Generator[Incomplete, Incomplete, None]: ...

[case testCallable]
from typing import Callable

Expand Down Expand Up @@ -2977,13 +3044,17 @@ def func(*, non_default_kwarg: bool, default_kwarg: bool = True): ...
def func(*, non_default_kwarg: bool, default_kwarg: bool = ...): ...

[case testNestedGenerator]
def f():
def f1():
def g():
yield 0

return 0
def f2():
def g():
yield from [0]
return 0
[out]
def f(): ...
def f1(): ...
def f2(): ...

[case testKnownMagicMethodsReturnTypes]
class Some:
Expand Down Expand Up @@ -3193,6 +3264,10 @@ def gen():
y = yield x
return z

def gen2():
y = yield from x
return z

class X(unknown_call("X", "a b")): ...
class Y(collections.namedtuple("Y", xx)): ...
[out]
Expand Down Expand Up @@ -3227,6 +3302,7 @@ TD2: _Incomplete
TD3: _Incomplete

def gen() -> _Generator[_Incomplete, _Incomplete, _Incomplete]: ...
def gen2() -> _Generator[_Incomplete, _Incomplete, _Incomplete]: ...

class X(_Incomplete): ...
class Y(_Incomplete): ...

0 comments on commit 2ede35f

Please sign in to comment.