Skip to content

Commit

Permalink
Partial support for PEP 695-style type aliases (#690)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored Sep 24, 2023
1 parent 201fbc9 commit 7c661f5
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Unreleased

- Partial support for PEP 695-style type aliases (#690)
- Add option to disable all error codes (#659)
- Add hacky fix for bugs with hashability on type objects (#689)
- Show an error on calls to `typing.Any` (#688)
Expand Down
43 changes: 43 additions & 0 deletions pyanalyze/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@
_HashableValue,
DictIncompleteValue,
KVPair,
TypeAlias,
TypeAliasValue,
annotate_value,
AnnotatedValue,
AnySource,
Expand Down Expand Up @@ -202,6 +204,14 @@ def get_attribute(self, root_value: Value, node: ast.Attribute) -> Value:
self.show_error(f"Cannot resolve annotation {root_value}", node=node)
return AnyValue(AnySource.error)

def get_type_alias(
self,
key: object,
evaluator: Callable[[], Value],
evaluate_type_params: Callable[[], Sequence[TypeVarLike]],
) -> TypeAlias:
return TypeAlias(evaluator, evaluate_type_params)


@dataclass
class RuntimeEvaluator(type_evaluation.Evaluator, Context):
Expand Down Expand Up @@ -489,6 +499,13 @@ def _type_from_runtime(
return _eval_forward_ref(
val.__forward_arg__, ctx, is_typeddict=is_typeddict
)
elif is_instance_of_typing_name(val, "TypeAliasType"):
alias = ctx.get_type_alias(
val,
lambda: type_from_runtime(val.__value__, ctx=ctx),
lambda: val.__type_params__,
)
return TypeAliasValue(val.__name__, val.__module__, alias)
elif val is Ellipsis:
# valid in Callable[..., ]
return AnyValue(AnySource.explicit)
Expand Down Expand Up @@ -868,6 +885,21 @@ def get_name(self, node: ast.Name) -> Value:
)
return AnyValue(AnySource.error)

def get_type_alias(
self,
key: object,
evaluator: Callable[[], Value],
evaluate_type_params: Callable[[], Sequence[TypeVarLike]],
) -> TypeAlias:
if self.visitor is not None:
cache = self.visitor.checker.type_alias_cache
if key in cache:
return cache[key]
alias = super().get_type_alias(key, evaluator, evaluate_type_params)
cache[key] = alias
return alias
return super().get_type_alias(key, evaluator, evaluate_type_params)


@dataclass(frozen=True)
class _SubscriptedValue(Value):
Expand Down Expand Up @@ -1177,6 +1209,17 @@ def _value_of_origin_args(
ctx.show_error("Unpack requires a single argument")
return AnyValue(AnySource.error)
return UnpackedValue(_type_from_runtime(args[0], ctx))
elif is_instance_of_typing_name(origin, "TypeAliasType"):
args_vals = [_type_from_runtime(val, ctx) for val in args]
alias_object = cast(Any, origin)
alias = ctx.get_type_alias(
val,
lambda: type_from_runtime(alias_object.__value__, ctx=ctx),
lambda: alias_object.__type_params__,
)
return TypeAliasValue(
alias_object.__name__, alias_object.__module__, alias, tuple(args_vals)
)
else:
ctx.show_error(
f"Unrecognized annotation {origin}[{', '.join(map(repr, args))}]"
Expand Down
18 changes: 13 additions & 5 deletions pyanalyze/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
KnownValueWithTypeVars,
MultiValuedValue,
SyntheticModuleValue,
TypeAliasValue,
annotate_value,
set_self,
SubclassValue,
Expand Down Expand Up @@ -94,12 +95,19 @@ def get_generic_bases(
return {}


def get_root_value(val: Value) -> Value:
if isinstance(val, AnnotatedValue):
return get_root_value(val.value)
elif isinstance(val, TypeAliasValue):
return get_root_value(val.get_value())
elif isinstance(val, TypeVarValue):
return get_root_value(val.get_fallback_value())
else:
return val


def get_attribute(ctx: AttrContext) -> Value:
root_value = ctx.root_value
if isinstance(root_value, TypeVarValue):
root_value = root_value.get_fallback_value()
elif isinstance(root_value, AnnotatedValue):
root_value = root_value.value
root_value = get_root_value(ctx.root_value)
if isinstance(root_value, KnownValue):
attribute_value = _get_attribute_from_known(root_value.val, ctx)
elif isinstance(root_value, TypedValue):
Expand Down
2 changes: 2 additions & 0 deletions pyanalyze/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
AnnotatedValue,
AnyValue,
CallableValue,
TypeAlias,
flatten_values,
is_union,
KnownValue,
Expand Down Expand Up @@ -101,6 +102,7 @@ class Checker:
default_factory=list
)
vnv_map: Dict[str, VariableNameValue] = field(default_factory=dict)
type_alias_cache: Dict[object, TypeAlias] = field(default_factory=dict)
_should_exclude_any: bool = False
_has_used_any_match: bool = False

Expand Down
69 changes: 68 additions & 1 deletion pyanalyze/value.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def can_assign(self, other: "Value", ctx: "CanAssignContext") -> "CanAssign":
return can_assign
bounds_maps.append(can_assign)
return unify_bounds_maps(bounds_maps)
elif isinstance(other, (AnnotatedValue, TypeVarValue)):
elif isinstance(other, (AnnotatedValue, TypeVarValue, TypeAliasValue)):
return other.can_be_assigned(self, ctx)
elif (
isinstance(other, UnboundMethodValue)
Expand Down Expand Up @@ -413,6 +413,73 @@ def can_assign(self, other: Value, ctx: CanAssignContext) -> CanAssign:
VOID = VoidValue()


@dataclass
class TypeAlias:
evaluator: Callable[[], Value]
"""Callable that evaluates the value."""
evaluate_type_params: Callable[[], Sequence[TypeVarLike]]
"""Callable that evaluates the type parameters."""
evaluated_value: Optional[Value] = None
"""Value that the type alias evaluates to."""
type_params: Optional[Sequence[TypeVarLike]] = None
"""Type parameters of the type alias."""

def get_value(self) -> Value:
if self.evaluated_value is None:
self.evaluated_value = self.evaluator()
return self.evaluated_value

def get_type_params(self) -> Sequence[TypeVarLike]:
if self.type_params is None:
self.type_params = self.evaluate_type_params()
return self.type_params


@dataclass(frozen=True)
class TypeAliasValue(Value):
"""Value representing a type alias."""

name: str
"""Name of the type alias."""
module: str
"""Module where the type alias is defined."""
alias: TypeAlias = field(compare=False, hash=False)
type_arguments: Sequence[Value] = ()

def get_value(self) -> Value:
val = self.alias.get_value()
if self.type_arguments:
type_params = self.alias.get_type_params()
if len(type_params) != len(self.type_arguments):
# TODO this should be an error
return AnyValue(AnySource.inference)
typevars = {
type_param: arg
for type_param, arg in zip(type_params, self.type_arguments)
}
val = val.substitute_typevars(typevars)
return val

def is_type(self, typ: type) -> bool:
return self.get_value().is_type(typ)

def get_type(self) -> Optional[type]:
return self.get_value().get_type()

def get_type_value(self) -> Value:
return self.get_value().get_type_value()

def can_assign(self, other: Value, ctx: CanAssignContext) -> CanAssign:
if isinstance(other, TypeAliasValue) and self.alias is other.alias:
return {}
return self.get_value().can_assign(other, ctx)

def can_be_assigned(self, other: Value, ctx: CanAssignContext) -> CanAssign:
if isinstance(other, TypeAliasValue) and self.alias is other.alias:
return {}
return other.can_assign(self.get_value(), ctx)


@dataclass(frozen=True)
class UninitializedValue(Value):
"""Value for variables that have not been initialized.
Expand Down

0 comments on commit 7c661f5

Please sign in to comment.