Skip to content

Commit

Permalink
Initial ParamSpec support (#352)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored Dec 20, 2021
1 parent 3bb42b6 commit cd3ac77
Show file tree
Hide file tree
Showing 7 changed files with 309 additions and 73 deletions.
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Unreleased

- Add basic support for `ParamSpec` (#352)
- Fix error on use of `AbstractAsyncContextManager` (#350)
- Check `with` and `async with` statements (#344)
- Improve type compatibility between generics and literals (#346)
Expand Down
168 changes: 128 additions & 40 deletions pyanalyze/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"""
from dataclasses import dataclass, InitVar, field
import typing

import typing_inspect
import qcore
import ast
Expand All @@ -46,6 +47,7 @@
Union,
TYPE_CHECKING,
)
from typing_extensions import ParamSpec

from .error_code import ErrorCode
from .extensions import (
Expand Down Expand Up @@ -363,6 +365,8 @@ def _type_from_runtime(val: Any, ctx: Context, is_typeddict: bool = False) -> Va
else:
constraints = ()
return TypeVarValue(tv, bound=bound, constraints=constraints)
elif is_instance_of_typing_name(val, "ParamSpec"):
return TypeVarValue(val, is_paramspec=True)
elif is_typing_name(val, "Final") or is_typing_name(val, "ClassVar"):
return AnyValue(AnySource.incomplete_annotation)
elif typing_inspect.is_classvar(val) or typing_inspect.is_final_type(val):
Expand Down Expand Up @@ -405,30 +409,14 @@ def _type_from_runtime(val: Any, ctx: Context, is_typeddict: bool = False) -> Va
[TypeGuardExtension(_type_from_runtime(val.__type__, ctx))],
)
elif isinstance(val, AsynqCallable):
arg_types = val.args
return_type = val.return_type
if arg_types is Ellipsis:
return CallableValue(
Signature.make(
[],
_type_from_runtime(return_type, ctx),
is_ellipsis_args=True,
is_asynq=True,
)
)
if not isinstance(arg_types, tuple):
ctx.show_error("Invalid arguments to AsynqCallable")
return AnyValue(AnySource.error)
params = [
SigParameter(
f"__arg{i}",
kind=ParameterKind.POSITIONAL_ONLY,
annotation=_type_from_runtime(arg, ctx),
)
for i, arg in enumerate(arg_types)
]
params, is_ellipsis_args = _callable_args_from_runtime(
val.args, "AsynqCallable", ctx
)
sig = Signature.make(
params, _type_from_runtime(return_type, ctx), is_asynq=True
params,
_type_from_runtime(val.return_type, ctx),
is_asynq=True,
is_ellipsis_args=is_ellipsis_args,
)
return CallableValue(sig)
elif isinstance(val, ExternalType):
Expand Down Expand Up @@ -458,6 +446,66 @@ def _type_from_runtime(val: Any, ctx: Context, is_typeddict: bool = False) -> Va
return AnyValue(AnySource.error)


def _callable_args_from_runtime(
arg_types: Any, label: str, ctx: Context
) -> Tuple[Sequence[SigParameter], bool]:
if arg_types is Ellipsis or arg_types == [Ellipsis]:
return [], True
elif type(arg_types) in (tuple, list):
if len(arg_types) == 1:
(arg,) = arg_types
if arg is Ellipsis:
return [], True
elif is_typing_name(getattr(arg, "__origin__", None), "Concatenate"):
return _args_from_concatenate(arg, ctx)
elif is_instance_of_typing_name(arg, "ParamSpec"):
param_spec = TypeVarValue(arg, is_paramspec=True)
param = SigParameter(
"__P", kind=ParameterKind.PARAM_SPEC, annotation=param_spec
)
return [param], False
types = [_type_from_runtime(arg, ctx) for arg in arg_types]
params = [
SigParameter(
f"__arg{i}",
kind=ParameterKind.PARAM_SPEC
if isinstance(typ, TypeVarValue) and typ.is_paramspec
else ParameterKind.POSITIONAL_ONLY,
annotation=typ,
)
for i, typ in enumerate(types)
]
return params, False
elif is_instance_of_typing_name(arg_types, "ParamSpec"):
param_spec = TypeVarValue(arg_types, is_paramspec=True)
param = SigParameter(
"__P", kind=ParameterKind.PARAM_SPEC, annotation=param_spec
)
return [param], False
elif is_typing_name(getattr(arg_types, "__origin__", None), "Concatenate"):
return _args_from_concatenate(arg_types, ctx)
else:
ctx.show_error(f"Invalid arguments to {label}: {arg_types!r}")
return [], True


def _args_from_concatenate(
concatenate: Any, ctx: Context
) -> Tuple[Sequence[SigParameter], bool]:
types = [_type_from_runtime(arg, ctx) for arg in concatenate.__args__]
params = [
SigParameter(
f"__arg{i}",
kind=ParameterKind.PARAM_SPEC
if i == len(types) - 1
else ParameterKind.POSITIONAL_ONLY,
annotation=annotation,
)
for i, annotation in enumerate(types)
]
return params, False


def _get_typeddict_value(
value: Value,
ctx: Context,
Expand Down Expand Up @@ -815,6 +863,21 @@ def visit_Call(self, node: ast.Call) -> Optional[Value]:
return None
tv = TypeVar(name_val.val)
return TypeVarValue(tv, bound, tuple(constraints))
elif is_typing_name(func.val, "ParamSpec"):
arg_values = [self.visit(arg) for arg in node.args]
kwarg_values = [(kw.arg, self.visit(kw.value)) for kw in node.keywords]
if not arg_values:
self.ctx.show_error("ParamSpec() requires at least one argument")
return None
name_val = arg_values[0]
if not isinstance(name_val, KnownValue):
self.ctx.show_error("ParamSpec name must be a literal")
return None
for name, _ in kwarg_values:
self.ctx.show_error(f"Unrecognized ParamSpec kwarg {name}")
return None
tv = ParamSpec(name_val.val)
return TypeVarValue(tv, is_paramspec=True)
elif isinstance(func.val, type):
if func.val is object:
return AnyValue(AnySource.inference)
Expand Down Expand Up @@ -847,26 +910,19 @@ def _value_of_origin_args(
elif origin is typing.Union:
return unite_values(*[_type_from_runtime(arg, ctx) for arg in args])
elif origin is Callable or origin is typing.Callable:
if len(args) == 2 and args[0] is Ellipsis:
return CallableValue(
Signature.make(
[], _type_from_runtime(args[1], ctx), is_ellipsis_args=True
)
)
elif len(args) == 0:
if len(args) == 0:
return TypedValue(Callable)
*arg_types, return_type = args
if len(arg_types) == 1 and isinstance(arg_types[0], list):
arg_types = arg_types[0]
params = [
SigParameter(
f"__arg{i}",
kind=ParameterKind.POSITIONAL_ONLY,
annotation=_type_from_runtime(arg, ctx, is_typeddict=True),
)
for i, arg in enumerate(arg_types)
]
sig = Signature.make(params, _type_from_runtime(return_type, ctx))
params, is_ellipsis_args = _callable_args_from_runtime(
arg_types, "Callable", ctx
)
sig = Signature.make(
params,
_type_from_runtime(return_type, ctx),
is_ellipsis_args=is_ellipsis_args,
)
return CallableValue(sig)
elif is_typing_name(origin, "Annotated"):
origin, metadata = args
Expand Down Expand Up @@ -966,8 +1022,40 @@ def _make_callable_from_value(
]
sig = Signature.make(params, return_annotation, is_asynq=is_asynq)
return CallableValue(sig)
elif isinstance(args, KnownValue) and is_instance_of_typing_name(
args.val, "ParamSpec"
):
annotation = TypeVarValue(args.val, is_paramspec=True)
params = [
SigParameter("__P", kind=ParameterKind.PARAM_SPEC, annotation=annotation)
]
sig = Signature.make(params, return_annotation, is_asynq=is_asynq)
return CallableValue(sig)
elif isinstance(args, TypeVarValue) and args.is_paramspec:
params = [SigParameter("__P", kind=ParameterKind.PARAM_SPEC, annotation=args)]
sig = Signature.make(params, return_annotation, is_asynq=is_asynq)
return CallableValue(sig)
elif (
isinstance(args, _SubscriptedValue)
and isinstance(args.root, KnownValue)
and is_typing_name(args.root.val, "Concatenate")
):
annotations = [_type_from_value(arg, ctx) for arg in args.members]
params = [
SigParameter(
f"__arg{i}",
kind=ParameterKind.PARAM_SPEC
if i == len(annotations) - 1
else ParameterKind.POSITIONAL_ONLY,
annotation=annotation,
)
for i, annotation in enumerate(annotations)
]
sig = Signature.make(params, return_annotation, is_asynq=is_asynq)
return CallableValue(sig)
else:
return AnyValue(AnySource.inference)
ctx.show_error(f"Unrecognized Callable type argument {args}")
return AnyValue(AnySource.error)


def _make_annotated(origin: Value, metadata: Sequence[Value], ctx: Context) -> Value:
Expand Down
96 changes: 70 additions & 26 deletions pyanalyze/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
AnySource,
AnyValue,
AsyncTaskIncompleteValue,
CallableValue,
CanAssignContext,
GenericValue,
HasAttrExtension,
Expand Down Expand Up @@ -234,6 +235,7 @@ class ParameterKind(enum.Enum):
VAR_POSITIONAL = 2
KEYWORD_ONLY = 3
VAR_KEYWORD = 4
PARAM_SPEC = 5


@dataclass
Expand Down Expand Up @@ -299,7 +301,7 @@ def __str__(self) -> str:

if kind is ParameterKind.VAR_POSITIONAL:
formatted = "*" + formatted
elif kind is ParameterKind.VAR_KEYWORD:
elif kind is ParameterKind.VAR_KEYWORD or kind is ParameterKind.PARAM_SPEC:
formatted = "**" + formatted

return formatted
Expand Down Expand Up @@ -919,6 +921,7 @@ def can_assign(
kwargs_annotation = None
consumed_positional = set()
consumed_keyword = set()
consumed_paramspec = False
for i, my_param in enumerate(self.parameters.values()):
my_annotation = my_param.get_annotation()
if my_param.kind is ParameterKind.POSITIONAL_ONLY:
Expand Down Expand Up @@ -1078,29 +1081,47 @@ def can_assign(
[tv_map],
)
tv_maps.append(tv_map)
elif my_param.kind is ParameterKind.PARAM_SPEC:
remaining = [
param
for param in other.parameters.values()
if param.name not in consumed_positional
and param.name not in consumed_keyword
]
new_sig = Signature.make(remaining)
assert isinstance(my_annotation, TypeVarValue)
tv_maps.append({my_annotation.typevar: CallableValue(new_sig)})
consumed_paramspec = True
else:
assert False, f"unhandled param {my_param}"

for param in their_params:
if (
param.kind is ParameterKind.VAR_POSITIONAL
or param.kind is ParameterKind.VAR_KEYWORD
):
continue # ok if they have extra *args or **kwargs
elif param.default is not None:
continue
elif param.kind is ParameterKind.POSITIONAL_ONLY:
if param.name not in consumed_positional:
return CanAssignError(
f"takes extra positional-only parameter {param.name!r}"
)
elif param.kind is ParameterKind.POSITIONAL_OR_KEYWORD:
if not consumed_paramspec:
for param in their_params:
if (
param.name not in consumed_positional
and param.name not in consumed_keyword
param.kind is ParameterKind.VAR_POSITIONAL
or param.kind is ParameterKind.VAR_KEYWORD
):
return CanAssignError(f"takes extra parameter {param.name!r}")
elif param.kind is ParameterKind.KEYWORD_ONLY:
if param.name not in consumed_keyword:
return CanAssignError(f"takes extra parameter {param.name!r}")
continue # ok if they have extra *args or **kwargs
elif param.default is not None:
continue
elif param.kind is ParameterKind.POSITIONAL_ONLY:
if param.name not in consumed_positional:
return CanAssignError(
f"takes extra positional-only parameter {param.name!r}"
)
elif param.kind is ParameterKind.POSITIONAL_OR_KEYWORD:
if (
param.name not in consumed_positional
and param.name not in consumed_keyword
):
return CanAssignError(f"takes extra parameter {param.name!r}")
elif param.kind is ParameterKind.KEYWORD_ONLY:
if param.name not in consumed_keyword:
return CanAssignError(f"takes extra parameter {param.name!r}")
elif param.kind is ParameterKind.PARAM_SPEC:
return CanAssignError(f"takes extra ParamSpec {param!r}")
else:
assert False, f"unhandled param {param}"

return unify_typevar_maps(tv_maps)

Expand All @@ -1111,17 +1132,40 @@ def get_param_of_kind(self, kind: ParameterKind) -> Optional[SigParameter]:
return None

def substitute_typevars(self, typevars: TypeVarMap) -> "Signature":
params = []
is_ellipsis_args = self.is_ellipsis_args
for name, param in self.parameters.items():
if param.kind is ParameterKind.PARAM_SPEC:
assert isinstance(param.annotation, TypeVarValue)
tv = param.annotation.typevar
if tv in typevars:
new_val = typevars[tv].substitute_typevars(typevars)
if isinstance(new_val, TypeVarValue):
assert new_val.is_paramspec, new_val
new_param = SigParameter(
param.name, param.kind, annotation=new_val
)
params.append((name, new_param))
elif isinstance(new_val, AnyValue):
is_ellipsis_args = True
params = []
break
else:
assert isinstance(new_val, CallableValue), new_val
assert isinstance(new_val.signature, Signature), new_val
params += list(new_val.signature.parameters.items())
else:
params.append((name, param))
else:
params.append((name, param.substitute_typevars(typevars)))
return Signature(
{
name: param.substitute_typevars(typevars)
for name, param in self.parameters.items()
},
dict(params),
self.return_value.substitute_typevars(typevars),
impl=self.impl,
callable=self.callable,
is_asynq=self.is_asynq,
has_return_annotation=self.has_return_annotation,
is_ellipsis_args=self.is_ellipsis_args,
is_ellipsis_args=is_ellipsis_args,
allow_call=self.allow_call,
)

Expand Down
Loading

0 comments on commit cd3ac77

Please sign in to comment.