Skip to content

Commit

Permalink
Add support for Unpack in args and kwargs (#523)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored Apr 17, 2022
1 parent 5080a78 commit 01ae202
Show file tree
Hide file tree
Showing 11 changed files with 299 additions and 98 deletions.
2 changes: 2 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## Unreleased

- Add support for use of the `Unpack` operator to
annotate heterogeneous `*args` and `**kwargs` parameters (#523)
- Detect incompatible types for some calls to `list.append`,
`list.extend`, `list.__add__`, and `set.add` (#522)
- Optimize local variables with very complex inferred types (#521)
Expand Down
107 changes: 52 additions & 55 deletions pyanalyze/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
SequenceValue,
TypeGuardExtension,
TypedValue,
UnpackedValue,
annotate_value,
unite_values,
Value,
Expand Down Expand Up @@ -278,6 +279,8 @@ def type_from_runtime(
node: Optional[ast.AST] = None,
globals: Optional[Mapping[str, object]] = None,
ctx: Optional[Context] = None,
*,
allow_unpack: bool = False,
) -> Value:
"""Given a runtime annotation object, return a
:class:`Value <pyanalyze.value.Value>`.
Expand All @@ -297,19 +300,23 @@ def type_from_runtime(
:param ctx: :class:`Context` to use for evaluation.
:param allow_unpack: Whether to allow `Unpack` types.
"""

if ctx is None:
ctx = _DefaultContext(visitor, node, globals)
return _type_from_runtime(val, ctx)
return _type_from_runtime(val, ctx, allow_unpack=allow_unpack)


def type_from_value(
value: Value,
visitor: Optional["NameCheckVisitor"] = None,
node: Optional[ast.AST] = None,
ctx: Optional[Context] = None,
*,
is_typeddict: bool = False,
allow_unpack: bool = False,
) -> Value:
"""Given a :class:`Value <pyanalyze.value.Value` representing an annotation,
return a :class:`Value <pyanalyze.value.Value>` representing the type.
Expand All @@ -336,7 +343,9 @@ def type_from_value(
"""
if ctx is None:
ctx = _DefaultContext(visitor, node)
return _type_from_value(value, ctx, is_typeddict=is_typeddict)
return _type_from_value(
value, ctx, is_typeddict=is_typeddict, allow_unpack=allow_unpack
)


def value_from_ast(
Expand All @@ -355,20 +364,20 @@ def _type_from_ast(
ctx: Context,
*,
is_typeddict: bool = False,
unpack_allowed: bool = False,
allow_unpack: bool = False,
) -> Value:
val = value_from_ast(node, ctx)
return _type_from_value(
val, ctx, is_typeddict=is_typeddict, unpack_allowed=unpack_allowed
val, ctx, is_typeddict=is_typeddict, allow_unpack=allow_unpack
)


def _type_from_runtime(
val: Any, ctx: Context, *, is_typeddict: bool = False, unpack_allowed: bool = False
val: Any, ctx: Context, *, is_typeddict: bool = False, allow_unpack: bool = False
) -> Value:
if isinstance(val, str):
return _eval_forward_ref(
val, ctx, is_typeddict=is_typeddict, unpack_allowed=unpack_allowed
val, ctx, is_typeddict=is_typeddict, allow_unpack=allow_unpack
)
elif isinstance(val, tuple):
# This happens under some Python versions for types
Expand All @@ -383,16 +392,14 @@ def _type_from_runtime(
args = (val[1],)
else:
args = val[1:]
return _value_of_origin_args(
origin, args, val, ctx, unpack_allowed=unpack_allowed
)
return _value_of_origin_args(origin, args, val, ctx, allow_unpack=allow_unpack)
elif GenericAlias is not None and isinstance(val, GenericAlias):
origin = get_origin(val)
args = get_args(val)
if origin is tuple and not args:
return SequenceValue(tuple, [])
return _value_of_origin_args(
origin, args, val, ctx, unpack_allowed=origin is tuple
origin, args, val, ctx, allow_unpack=origin is tuple
)
elif typing_inspect.is_literal_type(val):
args = typing_inspect.get_args(val)
Expand All @@ -417,7 +424,8 @@ def _type_from_runtime(
else:
return _make_sequence_value(
tuple,
[_type_from_runtime(arg, ctx, unpack_allowed=True) for arg in args],
[_type_from_runtime(arg, ctx, allow_unpack=True) for arg in args],
ctx,
)
elif is_instance_of_typing_name(val, "_TypedDictMeta"):
required_keys = getattr(val, "__required_keys__", None)
Expand Down Expand Up @@ -464,7 +472,7 @@ def _type_from_runtime(
val,
ctx,
is_typeddict=is_typeddict,
unpack_allowed=unpack_allowed or origin is tuple or origin is Tuple,
allow_unpack=allow_unpack or origin is tuple or origin is Tuple,
)
elif typing_inspect.is_callable_type(val):
args = typing_inspect.get_args(val)
Expand Down Expand Up @@ -568,8 +576,8 @@ def _type_from_runtime(
return AnyValue(AnySource.error)
# Also 3.6 only.
elif is_instance_of_typing_name(val, "_Unpack"):
if unpack_allowed:
return _make_unpacked_value(_type_from_runtime(val.__type__, ctx), ctx)
if allow_unpack:
return UnpackedValue(_type_from_runtime(val.__type__, ctx))
else:
ctx.show_error("Unpack[] used in unsupported context")
return AnyValue(AnySource.error)
Expand Down Expand Up @@ -622,7 +630,7 @@ def _callable_args_from_runtime(
types = [_type_from_runtime(arg, ctx) for arg in arg_types]
params = [
SigParameter(
f"__arg{i}",
f"@{i}",
kind=ParameterKind.PARAM_SPEC
if isinstance(typ, TypeVarValue) and typ.is_paramspec
else ParameterKind.POSITIONAL_ONLY,
Expand All @@ -648,7 +656,7 @@ def _args_from_concatenate(concatenate: Any, ctx: Context) -> Sequence[SigParame
types = [_type_from_runtime(arg, ctx) for arg in concatenate.__args__]
params = [
SigParameter(
f"__arg{i}",
f"@{i}",
kind=ParameterKind.PARAM_SPEC
if i == len(types) - 1
else ParameterKind.POSITIONAL_ONLY,
Expand Down Expand Up @@ -677,7 +685,7 @@ def _get_typeddict_value(


def _eval_forward_ref(
val: str, ctx: Context, *, is_typeddict: bool = False, unpack_allowed: bool = False
val: str, ctx: Context, *, is_typeddict: bool = False, allow_unpack: bool = False
) -> Value:
try:
tree = ast.parse(val, mode="eval")
Expand All @@ -686,7 +694,7 @@ def _eval_forward_ref(
return AnyValue(AnySource.error)
else:
return _type_from_ast(
tree.body, ctx, is_typeddict=is_typeddict, unpack_allowed=unpack_allowed
tree.body, ctx, is_typeddict=is_typeddict, allow_unpack=allow_unpack
)


Expand All @@ -695,19 +703,19 @@ def _type_from_value(
ctx: Context,
*,
is_typeddict: bool = False,
unpack_allowed: bool = False,
allow_unpack: bool = False,
) -> Value:
if isinstance(value, KnownValue):
return _type_from_runtime(
value.val, ctx, is_typeddict=is_typeddict, unpack_allowed=unpack_allowed
value.val, ctx, is_typeddict=is_typeddict, allow_unpack=allow_unpack
)
elif isinstance(value, TypeVarValue):
return value
elif isinstance(value, MultiValuedValue):
return unite_values(
*[
_type_from_value(
val, ctx, is_typeddict=is_typeddict, unpack_allowed=unpack_allowed
val, ctx, is_typeddict=is_typeddict, allow_unpack=allow_unpack
)
for val in value.vals
]
Expand All @@ -720,7 +728,7 @@ def _type_from_value(
value.members,
ctx,
is_typeddict=is_typeddict,
unpack_allowed=unpack_allowed,
allow_unpack=allow_unpack,
)
elif isinstance(value, AnyValue):
return value
Expand All @@ -740,7 +748,7 @@ def _type_from_subscripted_value(
ctx: Context,
*,
is_typeddict: bool = False,
unpack_allowed: bool = False,
allow_unpack: bool = False,
) -> Value:
if isinstance(root, GenericValue):
if len(root.args) == len(members):
Expand All @@ -758,7 +766,7 @@ def _type_from_subscripted_value(
members,
ctx,
is_typeddict=is_typeddict,
unpack_allowed=unpack_allowed,
allow_unpack=allow_unpack,
)
for subval in root.vals
]
Expand Down Expand Up @@ -800,7 +808,8 @@ def _type_from_subscripted_value(
else:
return _make_sequence_value(
tuple,
[_type_from_value(arg, ctx, unpack_allowed=True) for arg in members],
[_type_from_value(arg, ctx, allow_unpack=True) for arg in members],
ctx,
)
elif root is typing.Optional:
if len(members) != 1:
Expand Down Expand Up @@ -840,13 +849,13 @@ def _type_from_subscripted_value(
return AnyValue(AnySource.error)
return Pep655Value(False, _type_from_value(members[0], ctx))
elif is_typing_name(root, "Unpack"):
if not unpack_allowed:
if not allow_unpack:
ctx.show_error("Unpack[] used in unsupported context")
return AnyValue(AnySource.error)
if len(members) != 1:
ctx.show_error("Unpack requires a single argument")
return AnyValue(AnySource.error)
return _make_unpacked_value(_type_from_value(members[0], ctx), ctx)
return UnpackedValue(_type_from_value(members[0], ctx))
elif root is Callable or root is typing.Callable:
if len(members) == 2:
args, return_value = members
Expand Down Expand Up @@ -955,11 +964,6 @@ class Pep655Value(Value):
value: Value


@dataclass
class UnpackedValue(Value):
elements: Sequence[Tuple[bool, Value]]


class _Visitor(ast.NodeVisitor):
def __init__(self, ctx: Context) -> None:
self.ctx = ctx
Expand Down Expand Up @@ -1136,7 +1140,7 @@ def _value_of_origin_args(
ctx: Context,
*,
is_typeddict: bool = False,
unpack_allowed: bool = False,
allow_unpack: bool = False,
) -> Value:
if origin is typing.Type or origin is type:
if not args:
Expand All @@ -1151,9 +1155,9 @@ def _value_of_origin_args(
return SequenceValue(tuple, [])
else:
args_vals = [
_type_from_runtime(arg, ctx, unpack_allowed=True) for arg in args
_type_from_runtime(arg, ctx, allow_unpack=True) for arg in args
]
return _make_sequence_value(tuple, args_vals)
return _make_sequence_value(tuple, args_vals, ctx)
elif origin is typing.Union:
return unite_values(*[_type_from_runtime(arg, ctx) for arg in args])
elif origin is Callable or origin is typing.Callable:
Expand Down Expand Up @@ -1218,13 +1222,13 @@ def _value_of_origin_args(
return AnyValue(AnySource.error)
return Pep655Value(False, _type_from_runtime(args[0], ctx))
elif is_typing_name(origin, "Unpack"):
if not unpack_allowed:
if not allow_unpack:
ctx.show_error("Invalid usage of Unpack")
return AnyValue(AnySource.error)
if len(args) != 1:
ctx.show_error("Unpack requires a single argument")
return AnyValue(AnySource.error)
return _make_unpacked_value(_type_from_runtime(args[0], ctx), ctx)
return UnpackedValue(_type_from_runtime(args[0], ctx))
elif origin is None and isinstance(val, type):
# This happens for SupportsInt in 3.7.
return _maybe_typed_value(val)
Expand All @@ -1243,27 +1247,22 @@ def _maybe_typed_value(val: Union[type, str]) -> Value:
return TypedValue(val)


def _make_sequence_value(typ: type, members: Sequence[Value]) -> SequenceValue:
def _make_sequence_value(
typ: type, members: Sequence[Value], ctx: Context
) -> SequenceValue:
pairs = []
for val in members:
if isinstance(val, UnpackedValue):
pairs += val.elements
elements = val.get_elements()
if elements is None:
ctx.show_error(f"Invalid usage of Unpack with {val}")
elements = [(True, AnyValue(AnySource.error))]
pairs += elements
else:
pairs.append((False, val))
return SequenceValue(typ, pairs)


def _make_unpacked_value(val: Value, ctx: Context) -> UnpackedValue:
if isinstance(val, SequenceValue) and val.typ is tuple:
return UnpackedValue(val.members)
elif isinstance(val, GenericValue) and val.typ is tuple:
return UnpackedValue([(True, val.args[0])])
elif isinstance(val, TypedValue) and val.typ is tuple:
return UnpackedValue([(True, AnyValue(AnySource.generic_argument))])
ctx.show_error(f"Invalid argument for Unpack: {val}")
return UnpackedValue([])


def _make_callable_from_value(
args: Value, return_value: Value, ctx: Context, is_asynq: bool = False
) -> Value:
Expand All @@ -1280,15 +1279,13 @@ def _make_callable_from_value(
annotation = _type_from_value(arg, ctx)
if is_many:
param = SigParameter(
f"__arg{i}",
f"@{i}",
kind=ParameterKind.VAR_POSITIONAL,
annotation=GenericValue(tuple, [annotation]),
)
else:
param = SigParameter(
f"__arg{i}",
kind=ParameterKind.POSITIONAL_ONLY,
annotation=annotation,
f"@{i}", kind=ParameterKind.POSITIONAL_ONLY, annotation=annotation
)
params.append(param)
try:
Expand Down Expand Up @@ -1318,7 +1315,7 @@ def _make_callable_from_value(
annotations = [_type_from_value(arg, ctx) for arg in args.members]
params = [
SigParameter(
f"__arg{i}",
f"@{i}",
kind=ParameterKind.PARAM_SPEC
if i == len(annotations) - 1
else ParameterKind.POSITIONAL_ONLY,
Expand Down
11 changes: 5 additions & 6 deletions pyanalyze/arg_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

from .functions import translate_vararg_type
from .options import Options, PyObjectSequenceOption
from .analysis_lib import is_positional_only_arg_name
from .extensions import CustomCheck, TypeGuard, get_overloads, get_type_evaluations
Expand Down Expand Up @@ -423,14 +424,12 @@ def _get_type_for_parameter(
is_constructor: bool,
) -> Value:
if parameter.annotation is not inspect.Parameter.empty:
kind = ParameterKind(parameter.kind)
ctx = AnnotationsContext(self, func_globals)
typ = type_from_runtime(
parameter.annotation, ctx=AnnotationsContext(self, func_globals)
parameter.annotation, ctx=ctx, allow_unpack=kind.allow_unpack()
)
if parameter.kind is inspect.Parameter.VAR_POSITIONAL:
return GenericValue(tuple, [typ])
elif parameter.kind is inspect.Parameter.VAR_KEYWORD:
return GenericValue(dict, [TypedValue(str), typ])
return typ
return translate_vararg_type(kind, typ, self.ctx)
# If this is the self argument of a method, try to infer the self type.
elif index == 0 and parameter.kind in (
inspect.Parameter.POSITIONAL_ONLY,
Expand Down
Loading

0 comments on commit 01ae202

Please sign in to comment.