Skip to content

Commit

Permalink
Fix use of Literal with new typing_extensions (#628)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored May 23, 2023
1 parent 27fbd18 commit 471f33c
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 25 deletions.
4 changes: 4 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## Unreleased

- Fix use of `Literal` types with `typing_extensions` 4.6.0 (#628)

## Version 0.10.1 (May 22, 2023)

- Fix errors with protocol matching on `typing_extensions` 4.6.0
Expand Down
50 changes: 25 additions & 25 deletions pyanalyze/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
import qcore

import typing_inspect
from typing_extensions import ParamSpec, TypedDict
from typing_extensions import ParamSpec, TypedDict, get_origin, get_args

from . import type_evaluation

Expand Down Expand Up @@ -113,15 +113,12 @@

try:
from types import GenericAlias
from typing import get_args, get_origin # Python 3.9
except ImportError:
GenericAlias = None

def get_origin(obj: object) -> Any:
return None

def get_args(obj: object) -> Tuple[Any, ...]:
return ()
try:
from types import UnionType
except ImportError:
UnionType = None


CONTEXT_MANAGER_TYPES = (typing.ContextManager, contextlib.AbstractContextManager)
Expand Down Expand Up @@ -418,17 +415,20 @@ def _type_from_runtime(
return _value_of_origin_args(
origin, args, val, ctx, allow_unpack=origin is tuple
)
elif typing_inspect.is_literal_type(val):
args = typing_inspect.get_args(val)
if len(args) == 0:
origin = get_origin(val)
if is_typing_name(origin, "Literal"):
args = get_args(val)
if len(args) == 1:
return KnownValue(args[0])
else:
return unite_values(*[KnownValue(arg) for arg in args])
elif typing_inspect.is_union_type(val):
args = typing_inspect.get_args(val)
elif is_typing_name(origin, "Union") or (
UnionType is not None and origin is UnionType
):
args = get_args(val)
return unite_values(*[_type_from_runtime(arg, ctx) for arg in args])
elif typing_inspect.is_tuple_type(val):
args = typing_inspect.get_args(val)
elif origin is tuple or is_typing_name(origin, "Tuple"):
args = get_args(val)
if not args:
if val is tuple or val is Tuple:
return TypedValue(tuple)
Expand Down Expand Up @@ -490,8 +490,8 @@ def _type_from_runtime(
is_typeddict=is_typeddict,
allow_unpack=allow_unpack or origin is tuple or origin is Tuple,
)
elif typing_inspect.is_callable_type(val):
args = typing_inspect.get_args(val)
elif origin is Callable or is_typing_name(origin, "Callable"):
args = get_args(val)
return _value_of_origin_args(Callable, args, val, ctx)
elif val is AsynqCallable:
return CallableValue(Signature.make([ELLIPSIS_PARAM], is_asynq=True))
Expand All @@ -511,13 +511,14 @@ def _type_from_runtime(
if isinstance(val.__supertype__, type):
# NewType
return NewTypeValue(val)
elif typing_inspect.is_tuple_type(val.__supertype__):
super_origin = get_origin(val.__supertype__)
if super_origin is tuple or is_typing_name(super_origin, "Tuple"):
# TODO figure out how to make NewTypes over tuples work
return AnyValue(AnySource.inference)
else:
ctx.show_error(f"Invalid NewType {val}")
return AnyValue(AnySource.error)
elif typing_inspect.is_typevar(val):
elif is_typing_name(type(val), "TypeVar"):
tv = cast(TypeVar, val)
return make_type_var_value(tv, ctx)
elif is_instance_of_typing_name(val, "ParamSpec"):
Expand All @@ -528,7 +529,7 @@ def _type_from_runtime(
return ParamSpecKwargsValue(val.__origin__)
elif is_typing_name(val, "Final") or is_typing_name(val, "ClassVar"):
return AnyValue(AnySource.incomplete_annotation)
elif typing_inspect.is_classvar(val) or typing_inspect.is_final_type(val):
elif is_typing_name(origin, "ClassVar") or is_typing_name(origin, "Final"):
typ = val.__args__[0]
return _type_from_runtime(typ, ctx)
elif is_instance_of_typing_name(val, "_ForwardRef") or is_instance_of_typing_name(
Expand Down Expand Up @@ -571,12 +572,11 @@ def _type_from_runtime(
return AnyValue(AnySource.incomplete_annotation)
elif is_typing_name(val, "TypedDict"):
return KnownValue(TypedDict)
elif isinstance(origin, type):
return _maybe_typed_value(origin)
elif val is NamedTuple:
return TypedValue(tuple)
else:
origin = get_origin(val)
if isinstance(origin, type):
return _maybe_typed_value(origin)
elif val is NamedTuple:
return TypedValue(tuple)
ctx.show_error(f"Invalid type annotation {val}")
return AnyValue(AnySource.error)

Expand Down
12 changes: 12 additions & 0 deletions pyanalyze/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,18 @@ def capybara(x: Literal[True], y: Literal[True, False]) -> None:
assert_is_value(x, KnownValue(True))
assert_is_value(y, MultiValuedValue([KnownValue(True), KnownValue(False)]))

@assert_passes()
def test_literal_in_union(self):
from typing import Union

try:
from typing import Literal
except ImportError:
from typing_extensions import Literal

def capybara(x: Union[int, Literal["epoch"], None]) -> None:
assert_is_value(x, TypedValue(int) | KnownValue("epoch") | KnownValue(None))

@assert_passes()
def test_contextmanager(self):
from contextlib import contextmanager
Expand Down

0 comments on commit 471f33c

Please sign in to comment.