diff --git a/mypy/checker.py b/mypy/checker.py index c7de4911501a..6679e3d4aff7 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -843,6 +843,10 @@ def get_generator_yield_type(self, return_type: Type, is_coroutine: bool) -> Typ if isinstance(return_type, AnyType): return AnyType(TypeOfAny.from_another_any, source_any=return_type) + elif isinstance(return_type, UnionType): + return make_simplified_union( + [self.get_generator_yield_type(item, is_coroutine) for item in return_type.items] + ) elif not self.is_generator_return_type( return_type, is_coroutine ) and not self.is_async_generator_return_type(return_type): @@ -873,6 +877,10 @@ def get_generator_receive_type(self, return_type: Type, is_coroutine: bool) -> T if isinstance(return_type, AnyType): return AnyType(TypeOfAny.from_another_any, source_any=return_type) + elif isinstance(return_type, UnionType): + return make_simplified_union( + [self.get_generator_receive_type(item, is_coroutine) for item in return_type.items] + ) elif not self.is_generator_return_type( return_type, is_coroutine ) and not self.is_async_generator_return_type(return_type): @@ -912,6 +920,10 @@ def get_generator_return_type(self, return_type: Type, is_coroutine: bool) -> Ty if isinstance(return_type, AnyType): return AnyType(TypeOfAny.from_another_any, source_any=return_type) + elif isinstance(return_type, UnionType): + return make_simplified_union( + [self.get_generator_return_type(item, is_coroutine) for item in return_type.items] + ) elif not self.is_generator_return_type(return_type, is_coroutine): # If the function doesn't have a proper Generator (or # Awaitable) return type, anything is permissible. diff --git a/test-data/unit/check-async-await.test b/test-data/unit/check-async-await.test index 195e70cf5880..d53cba2fc642 100644 --- a/test-data/unit/check-async-await.test +++ b/test-data/unit/check-async-await.test @@ -925,3 +925,21 @@ async def f() -> AsyncGenerator[int, None]: [builtins fixtures/async_await.pyi] [typing fixtures/typing-async.pyi] + +[case testAwaitUnion] +from typing import overload, Union + +class A: ... +class B: ... + +@overload +async def foo(x: A) -> B: ... +@overload +async def foo(x: B) -> A: ... +async def foo(x): ... + +async def bar(x: Union[A, B]) -> None: + reveal_type(await foo(x)) # N: Revealed type is "Union[__main__.B, __main__.A]" + +[builtins fixtures/async_await.pyi] +[typing fixtures/typing-async.pyi] diff --git a/test-data/unit/check-statements.test b/test-data/unit/check-statements.test index 4be5060996e2..3450f8593d27 100644 --- a/test-data/unit/check-statements.test +++ b/test-data/unit/check-statements.test @@ -2206,3 +2206,12 @@ def foo(): x: int = "no" # N: By default the bodies of untyped functions are not checked, consider using --check-untyped-defs y = "no" # type: int # N: By default the bodies of untyped functions are not checked, consider using --check-untyped-defs z: int # N: By default the bodies of untyped functions are not checked, consider using --check-untyped-defs + +[case testGeneratorUnion] +from typing import Generator, Union + +class A: pass +class B: pass + +def foo(x: int) -> Union[Generator[A, None, None], Generator[B, None, None]]: + yield x # E: Incompatible types in "yield" (actual type "int", expected type "Union[A, B]")