Skip to content

Commit

Permalink
Remove WeakExtension (#517)
Browse files Browse the repository at this point in the history
With the new support for variable-length elements in SequenceValue,
we can use a SequenceValue with a single variable-length element
instead of a "weak" list. The behavior is roughly equivalent, but allows
more precise inference in some cases.

We add an impl for dict.pop to compensate for some of this more
precise inference.
  • Loading branch information
JelleZijlstra authored Apr 13, 2022
1 parent a6ba157 commit 40065df
Show file tree
Hide file tree
Showing 12 changed files with 273 additions and 205 deletions.
2 changes: 2 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## Unreleased

- Add implementation function for `dict.pop` (#517)
- Remove `WeakExtension` (#517)
- Fix propagation of no-return-unless constraints from calls
to unions (#518)
- Initial support for variable-length heterogeneous sequences
Expand Down
3 changes: 1 addition & 2 deletions pyanalyze/arg_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
Value,
TypeVarValue,
extract_typevars,
make_weak,
)
import pyanalyze

Expand Down Expand Up @@ -206,7 +205,7 @@ class FunctionsSafeToCall(PyObjectSequenceOption[object]):
arguments."""

name = "functions_safe_to_call"
default_value = [sorted, asynq.asynq, make_weak]
default_value = [sorted, asynq.asynq]


_HookReturn = Union[None, ConcreteSignature, inspect.Signature, Callable[..., Any]]
Expand Down
181 changes: 112 additions & 69 deletions pyanalyze/implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,9 @@
MultiValuedValue,
KNOWN_MUTABLE_TYPES,
Value,
WeakExtension,
check_hashability,
concrete_values_from_iterable,
kv_pairs_from_mapping,
make_weak,
unannotate,
unite_values,
flatten_values,
Expand Down Expand Up @@ -365,15 +363,8 @@ def _list_append_impl(ctx: CallContext) -> ImplReturn:
)
return ImplReturn(KnownValue(None), no_return_unless=no_return_unless)
elif isinstance(lst, GenericValue):
return _maybe_broaden_weak_type(
"list.append",
"object",
ctx.vars["self"],
lst,
element,
ctx,
list,
varname,
return _check_generic_container(
"list.append", "object", ctx.vars["self"], lst, element, ctx, list
)
return ImplReturn(KnownValue(None))

Expand Down Expand Up @@ -659,6 +650,90 @@ def inner(key: Value) -> Value:
return flatten_unions(inner, ctx.vars["key"])


def _dict_pop_impl(ctx: CallContext) -> ImplReturn:
key = ctx.vars["key"]
default = ctx.vars["default"]
varname = ctx.visitor.varname_for_self_constraint(ctx.node)
self_value = replace_known_sequence_value(ctx.vars["self"])

if not _check_dict_key_hashability(key, ctx, "key"):
return ImplReturn(AnyValue(AnySource.error))

if isinstance(self_value, TypedDictValue):
if not TypedValue(str).is_assignable(key, ctx.visitor):
ctx.show_error(
f"TypedDict key must be str, not {key}",
ErrorCode.invalid_typeddict_key,
arg="key",
)
return ImplReturn(AnyValue(AnySource.error))
elif isinstance(key, KnownValue):
try:
is_required, expected_type = self_value.items[key.val]
# probably KeyError, but catch anything in case it's an
# unhashable str subclass or something
except Exception:
pass
else:
if is_required:
ctx.show_error(
f"Cannot pop required TypedDict key {key}",
error_code=ErrorCode.incompatible_argument,
arg="key",
)
return ImplReturn(_maybe_unite(expected_type, default))
ctx.show_error(
f"Key {key} does not exist in TypedDict",
ErrorCode.invalid_typeddict_key,
arg="key",
)
return ImplReturn(default)
elif isinstance(self_value, DictIncompleteValue):
existing_value = self_value.get_value(key, ctx.visitor)
is_present = existing_value is not UNINITIALIZED_VALUE
if varname is not None and isinstance(key, KnownValue):
new_value = DictIncompleteValue(
self_value.typ,
[pair for pair in self_value.kv_pairs if pair.key != key],
)
no_return_unless = Constraint(
varname, ConstraintType.is_value_object, True, new_value
)
else:
no_return_unless = NULL_CONSTRAINT
if not is_present:
if default is _NO_ARG_SENTINEL:
ctx.show_error(
f"Key {key} does not exist in dictionary {self_value}",
error_code=ErrorCode.incompatible_argument,
arg="key",
)
return ImplReturn(AnyValue(AnySource.error))
return ImplReturn(default, no_return_unless=no_return_unless)
return ImplReturn(
_maybe_unite(existing_value, default), no_return_unless=no_return_unless
)
elif isinstance(self_value, TypedValue):
key_type = self_value.get_generic_arg_for_type(dict, ctx.visitor, 0)
value_type = self_value.get_generic_arg_for_type(dict, ctx.visitor, 1)
tv_map = key_type.can_assign(key, ctx.visitor)
if isinstance(tv_map, CanAssignError):
ctx.show_error(
f"Key {key} is not valid for {self_value}",
ErrorCode.incompatible_argument,
arg="key",
)
return ImplReturn(_maybe_unite(value_type, default))
else:
return ImplReturn(AnyValue(AnySource.inference))


def _maybe_unite(value: Value, default: Value) -> Value:
if default is _NO_ARG_SENTINEL:
return value
return unite_values(value, default)


def _dict_setdefault_impl(ctx: CallContext) -> ImplReturn:
key = ctx.vars["key"]
default = ctx.vars["default"]
Expand Down Expand Up @@ -721,27 +796,14 @@ def _dict_setdefault_impl(ctx: CallContext) -> ImplReturn:
key_type = self_value.get_generic_arg_for_type(dict, ctx.visitor, 0)
value_type = self_value.get_generic_arg_for_type(dict, ctx.visitor, 1)
new_value_type = unite_values(value_type, default)
if _is_weak(ctx.vars["self"]):
new_key_type = unite_values(key_type, key)
new_type = make_weak(
GenericValue(self_value.typ, [new_key_type, new_value_type])
tv_map = key_type.can_assign(key, ctx.visitor)
if isinstance(tv_map, CanAssignError):
ctx.show_error(
f"Key {key} is not valid for {self_value}",
ErrorCode.incompatible_argument,
arg="key",
)
if varname is not None:
no_return_unless = Constraint(
varname, ConstraintType.is_value_object, True, new_type
)
else:
no_return_unless = NULL_CONSTRAINT
return ImplReturn(new_value_type, no_return_unless=no_return_unless)
else:
tv_map = key_type.can_assign(key, ctx.visitor)
if isinstance(tv_map, CanAssignError):
ctx.show_error(
f"Key {key} is not valid for {self_value}",
ErrorCode.incompatible_argument,
arg="key",
)
return ImplReturn(new_value_type)
return ImplReturn(new_value_type)
else:
return ImplReturn(AnyValue(AnySource.inference))

Expand All @@ -767,7 +829,7 @@ def _unpack_iterable_of_pairs(
return kv_pairs


def _weak_dict_update(
def _update_incomplete_dict(
self_val: Value,
pairs: Sequence[KVPair],
ctx: CallContext,
Expand Down Expand Up @@ -799,17 +861,13 @@ def _add_pairs_to_dict(
ctx: CallContext,
varname: Optional[VarnameWithOrigin],
) -> ImplReturn:
if _is_weak(self_val):
return _weak_dict_update(self_val, pairs, ctx, varname)

# Now we don't care about Annotated
self_val = replace_known_sequence_value(self_val)
if isinstance(self_val, TypedDictValue):
for pair in pairs:
_typeddict_setitem(self_val, pair.key, pair.value, ctx)
return ImplReturn(KnownValue(None))
elif isinstance(self_val, DictIncompleteValue):
return _weak_dict_update(self_val, pairs, ctx, varname)
return _update_incomplete_dict(self_val, pairs, ctx, varname)
elif isinstance(self_val, TypedValue):
key_type = self_val.get_generic_arg_for_type(dict, ctx.visitor, 0)
value_type = self_val.get_generic_arg_for_type(dict, ctx.visitor, 1)
Expand All @@ -832,7 +890,7 @@ def _add_pairs_to_dict(
)
return ImplReturn(KnownValue(None))
else:
return _weak_dict_update(self_val, pairs, ctx, varname)
return _update_incomplete_dict(self_val, pairs, ctx, varname)


def _dict_update_impl(ctx: CallContext) -> ImplReturn:
Expand Down Expand Up @@ -949,15 +1007,14 @@ def inner(lst: Value, iterable: Value) -> ImplReturn:
actual_type = iterable.get_generic_arg_for_type(
collections.abc.Iterable, ctx.visitor, 0
)
return _maybe_broaden_weak_type(
return _check_generic_container(
name,
iterable_arg,
lst,
cleaned_lst,
actual_type,
ctx,
list,
varname,
return_container=return_container,
)
return ImplReturn(lst if return_container else KnownValue(None))
Expand All @@ -973,33 +1030,18 @@ def _list_iadd_impl(ctx: CallContext) -> ImplReturn:
return _list_extend_or_iadd_impl(ctx, "x", "list.__iadd__", return_container=True)


def _is_weak(val: Value) -> bool:
return isinstance(val, AnnotatedValue) and val.has_metadata_of_type(WeakExtension)


def _maybe_broaden_weak_type(
def _check_generic_container(
function_name: str,
arg: str,
original_container_type: Value,
container_type: Value,
container_type: GenericValue,
actual_type: Value,
ctx: CallContext,
typ: type,
varname: VarnameWithOrigin,
*,
return_container: bool = False,
) -> ImplReturn:
expected_type = container_type.get_generic_arg_for_type(typ, ctx.visitor, 0)
if _is_weak(original_container_type):
generic_arg = unite_values(expected_type, actual_type)
constrained_value = make_weak(GenericValue(typ, [generic_arg]))
no_return_unless = Constraint(
varname, ConstraintType.is_value_object, True, constrained_value
)
if return_container:
return ImplReturn(constrained_value)
return ImplReturn(KnownValue(None), no_return_unless=no_return_unless)

tv_map = expected_type.can_assign(actual_type, ctx.visitor)
if isinstance(tv_map, CanAssignError):
ctx.show_error(
Expand Down Expand Up @@ -1029,15 +1071,8 @@ def _set_add_impl(ctx: CallContext) -> ImplReturn:
)
return ImplReturn(KnownValue(None), no_return_unless=no_return_unless)
elif isinstance(set_value, GenericValue):
return _maybe_broaden_weak_type(
"set.add",
"object",
ctx.vars["self"],
set_value,
element,
ctx,
set,
varname,
return _check_generic_container(
"set.add", "object", ctx.vars["self"], set_value, element, ctx, set
)
return ImplReturn(KnownValue(None))

Expand Down Expand Up @@ -1533,6 +1568,15 @@ def get_default_argspecs() -> Dict[object, Signature]:
callable=dict.setdefault,
impl=_dict_setdefault_impl,
),
Signature.make(
[
SigParameter("self", _POS_ONLY, annotation=TypedValue(dict)),
SigParameter("key", _POS_ONLY),
SigParameter("default", _POS_ONLY, default=_NO_ARG_SENTINEL),
],
callable=dict.pop,
impl=_dict_pop_impl,
),
Signature.make(
[
SigParameter("self", _POS_ONLY, annotation=TypedValue(dict)),
Expand All @@ -1551,9 +1595,8 @@ def get_default_argspecs() -> Dict[object, Signature]:
annotation=GenericValue(dict, [TypeVarValue(K), TypeVarValue(V)]),
)
],
AnnotatedValue(
GenericValue(dict, [TypeVarValue(K), TypeVarValue(V)]),
[WeakExtension()],
DictIncompleteValue(
dict, [KVPair(TypeVarValue(K), TypeVarValue(V), is_many=True)]
),
callable=dict.copy,
),
Expand Down
12 changes: 6 additions & 6 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@
get_tv_map,
is_union,
kv_pairs_from_mapping,
make_weak,
set_self,
unannotate_value,
unite_and_simplify,
Expand Down Expand Up @@ -2468,10 +2467,9 @@ def _visit_comprehension_inner(
detail=str(hashability),
)
key_value = AnyValue(AnySource.error)
if isinstance(key_value, AnyValue) and isinstance(value_value, AnyValue):
return TypedValue(dict)
else:
return make_weak(GenericValue(dict, [key_value, value_value]))
return DictIncompleteValue(
dict, [KVPair(key_value, value_value, is_many=True)]
)

with qcore.override(self, "in_comprehension_body", True):
member_value = self.visit(node.elt)
Expand All @@ -2489,7 +2487,9 @@ def _visit_comprehension_inner(

if typ is types.GeneratorType:
return GenericValue(typ, [member_value, KnownValue(None), KnownValue(None)])
return make_weak(GenericValue(typ, [member_value]))
# Returning a SequenceValue here instead of a GenericValue allows
# later code to modify this container.
return SequenceValue(typ, [(True, member_value)])

# Literals and displays

Expand Down
13 changes: 10 additions & 3 deletions pyanalyze/test_async_await.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
# static analysis: ignore
from .tests import make_simple_sequence
from .value import GenericValue, KnownValue, TypedValue, make_weak, AnyValue, AnySource
from .value import (
GenericValue,
KnownValue,
SequenceValue,
TypedValue,
AnyValue,
AnySource,
)
from .implementation import assert_is_value
from .test_node_visitor import assert_passes, only_before
from .test_name_check_visitor import TestNameCheckVisitorBase
Expand Down Expand Up @@ -62,7 +69,7 @@ def __aiter__(self) -> ANext:

async def f():
x = [y async for y in AIter()]
assert_is_value(x, make_weak(GenericValue(list, [TypedValue(int)])))
assert_is_value(x, SequenceValue(list, [(True, TypedValue(int))]))

@assert_passes()
def test_async_generator(self):
Expand All @@ -82,7 +89,7 @@ async def capybara():
# TODO should be list[int] but we lose the type argument somewhere
assert_is_value(
ints,
make_weak(GenericValue(list, [AnyValue(AnySource.generic_argument)])),
SequenceValue(list, [(True, AnyValue(AnySource.generic_argument))]),
)

@assert_passes()
Expand Down
Loading

0 comments on commit 40065df

Please sign in to comment.