Skip to content

Commit

Permalink
Treat generators with await as async. (#12925)
Browse files Browse the repository at this point in the history
Treat generators with await as async.
  • Loading branch information
jhance authored Jun 2, 2022
1 parent 1636a05 commit d21c5ab
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 2 deletions.
5 changes: 3 additions & 2 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
make_optional_type,
)
from mypy.semanal_enum import ENUM_BASES
from mypy.traverser import has_await_expression
from mypy.types import (
Type, AnyType, CallableType, Overloaded, NoneType, TypeVarType,
TupleType, TypedDictType, Instance, ErasedType, UnionType,
Expand Down Expand Up @@ -3798,8 +3799,8 @@ def visit_set_comprehension(self, e: SetComprehension) -> Type:

def visit_generator_expr(self, e: GeneratorExpr) -> Type:
# If any of the comprehensions use async for, the expression will return an async generator
# object
if any(e.is_async):
# object, or if the left-side expression uses await.
if any(e.is_async) or has_await_expression(e.left_expr):
typ = 'typing.AsyncGenerator'
# received type is always None in async generator expressions
additional_args: List[Type] = [NoneType()]
Expand Down
16 changes: 16 additions & 0 deletions mypy/traverser.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ConditionalExpr, TypeApplication, ExecStmt, Import, ImportFrom,
LambdaExpr, ComparisonExpr, OverloadedFuncDef, YieldFromExpr,
YieldExpr, StarExpr, BackquoteExpr, AwaitExpr, PrintStmt, SuperExpr, Node, REVEAL_TYPE,
Expression,
)


Expand Down Expand Up @@ -397,6 +398,21 @@ def has_yield_expression(fdef: FuncBase) -> bool:
return seeker.found


class AwaitSeeker(TraverserVisitor):
def __init__(self) -> None:
super().__init__()
self.found = False

def visit_await_expr(self, o: AwaitExpr) -> None:
self.found = True


def has_await_expression(expr: Expression) -> bool:
seeker = AwaitSeeker()
expr.accept(seeker)
return seeker.found


class ReturnCollector(FuncCollectorBase):
def __init__(self) -> None:
super().__init__()
Expand Down
11 changes: 11 additions & 0 deletions test-data/unit/check-async-await.test
Original file line number Diff line number Diff line change
Expand Up @@ -864,3 +864,14 @@ async with C() as x: # E: "async with" outside async function

[builtins fixtures/async_await.pyi]
[typing fixtures/typing-async.pyi]

[case testAsyncGeneratorExpressionAwait]
from typing import AsyncGenerator

async def f() -> AsyncGenerator[int, None]:
async def g(x: int) -> int:
return x

return (await g(x) for x in [1, 2, 3])

[typing fixtures/typing-async.pyi]

0 comments on commit d21c5ab

Please sign in to comment.