From 2ede35fe24ad1c6c2444156c753609f6b7888064 Mon Sep 17 00:00:00 2001 From: Ali Hamdan Date: Mon, 22 May 2023 00:10:53 +0200 Subject: [PATCH] stubgen: Support `yield from` statements (#15271) Resolves #10744 --- mypy/stubgen.py | 27 ++++++++---- mypy/traverser.py | 36 ++++++++++++++++ test-data/unit/stubgen.test | 82 +++++++++++++++++++++++++++++++++++-- 3 files changed, 133 insertions(+), 12 deletions(-) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 32dc6a615f8c..9ac919046aef 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -126,7 +126,12 @@ report_missing, walk_packages, ) -from mypy.traverser import all_yield_expressions, has_return_statement, has_yield_expression +from mypy.traverser import ( + all_yield_expressions, + has_return_statement, + has_yield_expression, + has_yield_from_expression, +) from mypy.types import ( OVERLOAD_NAMES, TPDICT_NAMES, @@ -774,18 +779,22 @@ def visit_func_def(self, o: FuncDef) -> None: retname = None # implicit Any elif o.name in KNOWN_MAGIC_METHODS_RETURN_TYPES: retname = KNOWN_MAGIC_METHODS_RETURN_TYPES[o.name] - elif has_yield_expression(o): + elif has_yield_expression(o) or has_yield_from_expression(o): self.add_typing_import("Generator") yield_name = "None" send_name = "None" return_name = "None" - for expr, in_assignment in all_yield_expressions(o): - if expr.expr is not None and not self.is_none_expr(expr.expr): - self.add_typing_import("Incomplete") - yield_name = self.typing_name("Incomplete") - if in_assignment: - self.add_typing_import("Incomplete") - send_name = self.typing_name("Incomplete") + if has_yield_from_expression(o): + self.add_typing_import("Incomplete") + yield_name = send_name = self.typing_name("Incomplete") + else: + for expr, in_assignment in all_yield_expressions(o): + if expr.expr is not None and not self.is_none_expr(expr.expr): + self.add_typing_import("Incomplete") + yield_name = self.typing_name("Incomplete") + if in_assignment: + self.add_typing_import("Incomplete") + send_name = self.typing_name("Incomplete") if has_return_statement(o): self.add_typing_import("Incomplete") return_name = self.typing_name("Incomplete") diff --git a/mypy/traverser.py b/mypy/traverser.py index 038d948522f0..2fcc376cfb7c 100644 --- a/mypy/traverser.py +++ b/mypy/traverser.py @@ -873,6 +873,21 @@ def has_yield_expression(fdef: FuncBase) -> bool: return seeker.found +class YieldFromSeeker(FuncCollectorBase): + def __init__(self) -> None: + super().__init__() + self.found = False + + def visit_yield_from_expr(self, o: YieldFromExpr) -> None: + self.found = True + + +def has_yield_from_expression(fdef: FuncBase) -> bool: + seeker = YieldFromSeeker() + fdef.accept(seeker) + return seeker.found + + class AwaitSeeker(TraverserVisitor): def __init__(self) -> None: super().__init__() @@ -922,3 +937,24 @@ def all_yield_expressions(node: Node) -> list[tuple[YieldExpr, bool]]: v = YieldCollector() node.accept(v) return v.yield_expressions + + +class YieldFromCollector(FuncCollectorBase): + def __init__(self) -> None: + super().__init__() + self.in_assignment = False + self.yield_from_expressions: list[tuple[YieldFromExpr, bool]] = [] + + def visit_assignment_stmt(self, stmt: AssignmentStmt) -> None: + self.in_assignment = True + super().visit_assignment_stmt(stmt) + self.in_assignment = False + + def visit_yield_from_expr(self, expr: YieldFromExpr) -> None: + self.yield_from_expressions.append((expr, self.in_assignment)) + + +def all_yield_from_expressions(node: Node) -> list[tuple[YieldFromExpr, bool]]: + v = YieldFromCollector() + node.accept(v) + return v.yield_from_expressions diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 1834284ef48e..8c92e067b930 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -1231,6 +1231,9 @@ def h1(): def h2(): yield return "abc" +def h3(): + yield + return None def all(): x = yield 123 return "abc" @@ -1242,6 +1245,7 @@ def f() -> Generator[Incomplete, None, None]: ... def g() -> Generator[None, Incomplete, None]: ... def h1() -> Generator[None, None, None]: ... def h2() -> Generator[None, None, Incomplete]: ... +def h3() -> Generator[None, None, None]: ... def all() -> Generator[Incomplete, Incomplete, Incomplete]: ... [case testFunctionYieldsNone] @@ -1270,6 +1274,69 @@ class Generator: ... def f() -> _Generator[Incomplete, None, None]: ... +[case testGeneratorYieldFrom] +def g1(): + yield from x +def g2(): + y = yield from x +def g3(): + yield from x + return +def g4(): + yield from x + return None +def g5(): + yield from x + return z + +[out] +from _typeshed import Incomplete +from collections.abc import Generator + +def g1() -> Generator[Incomplete, Incomplete, None]: ... +def g2() -> Generator[Incomplete, Incomplete, None]: ... +def g3() -> Generator[Incomplete, Incomplete, None]: ... +def g4() -> Generator[Incomplete, Incomplete, None]: ... +def g5() -> Generator[Incomplete, Incomplete, Incomplete]: ... + +[case testGeneratorYieldAndYieldFrom] +def g1(): + yield x1 + yield from x2 +def g2(): + yield x1 + y = yield from x2 +def g3(): + y = yield x1 + yield from x2 +def g4(): + yield x1 + yield from x2 + return +def g5(): + yield x1 + yield from x2 + return None +def g6(): + yield x1 + yield from x2 + return z +def g7(): + yield None + yield from x2 + +[out] +from _typeshed import Incomplete +from collections.abc import Generator + +def g1() -> Generator[Incomplete, Incomplete, None]: ... +def g2() -> Generator[Incomplete, Incomplete, None]: ... +def g3() -> Generator[Incomplete, Incomplete, None]: ... +def g4() -> Generator[Incomplete, Incomplete, None]: ... +def g5() -> Generator[Incomplete, Incomplete, None]: ... +def g6() -> Generator[Incomplete, Incomplete, Incomplete]: ... +def g7() -> Generator[Incomplete, Incomplete, None]: ... + [case testCallable] from typing import Callable @@ -2977,13 +3044,17 @@ def func(*, non_default_kwarg: bool, default_kwarg: bool = True): ... def func(*, non_default_kwarg: bool, default_kwarg: bool = ...): ... [case testNestedGenerator] -def f(): +def f1(): def g(): yield 0 - + return 0 +def f2(): + def g(): + yield from [0] return 0 [out] -def f(): ... +def f1(): ... +def f2(): ... [case testKnownMagicMethodsReturnTypes] class Some: @@ -3193,6 +3264,10 @@ def gen(): y = yield x return z +def gen2(): + y = yield from x + return z + class X(unknown_call("X", "a b")): ... class Y(collections.namedtuple("Y", xx)): ... [out] @@ -3227,6 +3302,7 @@ TD2: _Incomplete TD3: _Incomplete def gen() -> _Generator[_Incomplete, _Incomplete, _Incomplete]: ... +def gen2() -> _Generator[_Incomplete, _Incomplete, _Incomplete]: ... class X(_Incomplete): ... class Y(_Incomplete): ...