Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Treat generators with await as async. #12925

Merged
merged 3 commits into from
Jun 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
make_optional_type,
)
from mypy.semanal_enum import ENUM_BASES
from mypy.traverser import has_await_expression
from mypy.types import (
Type, AnyType, CallableType, Overloaded, NoneType, TypeVarType,
TupleType, TypedDictType, Instance, ErasedType, UnionType,
Expand Down Expand Up @@ -3798,8 +3799,8 @@ def visit_set_comprehension(self, e: SetComprehension) -> Type:

def visit_generator_expr(self, e: GeneratorExpr) -> Type:
# If any of the comprehensions use async for, the expression will return an async generator
# object
if any(e.is_async):
# object, or if the left-side expression uses await.
if any(e.is_async) or has_await_expression(e.left_expr):
typ = 'typing.AsyncGenerator'
# received type is always None in async generator expressions
additional_args: List[Type] = [NoneType()]
Expand Down
16 changes: 16 additions & 0 deletions mypy/traverser.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ConditionalExpr, TypeApplication, ExecStmt, Import, ImportFrom,
LambdaExpr, ComparisonExpr, OverloadedFuncDef, YieldFromExpr,
YieldExpr, StarExpr, BackquoteExpr, AwaitExpr, PrintStmt, SuperExpr, Node, REVEAL_TYPE,
Expression,
)


Expand Down Expand Up @@ -397,6 +398,21 @@ def has_yield_expression(fdef: FuncBase) -> bool:
return seeker.found


class AwaitSeeker(TraverserVisitor):
def __init__(self) -> None:
super().__init__()
self.found = False

def visit_await_expr(self, o: AwaitExpr) -> None:
self.found = True


def has_await_expression(expr: Expression) -> bool:
seeker = AwaitSeeker()
expr.accept(seeker)
return seeker.found


class ReturnCollector(FuncCollectorBase):
def __init__(self) -> None:
super().__init__()
Expand Down
11 changes: 11 additions & 0 deletions test-data/unit/check-async-await.test
Original file line number Diff line number Diff line change
Expand Up @@ -864,3 +864,14 @@ async with C() as x: # E: "async with" outside async function

[builtins fixtures/async_await.pyi]
[typing fixtures/typing-async.pyi]

[case testAsyncGeneratorExpressionAwait]
from typing import AsyncGenerator

async def f() -> AsyncGenerator[int, None]:
async def g(x: int) -> int:
return x

return (await g(x) for x in [1, 2, 3])

[typing fixtures/typing-async.pyi]