Skip to content

Commit

Permalink
Add @has_extra_keys (#568)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored Nov 14, 2022
1 parent a6cc64f commit 2c93b3d
Show file tree
Hide file tree
Showing 9 changed files with 247 additions and 24 deletions.
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Unreleased

- Add experimental `@has_extra_keys` decorator for `TypedDictt` types (#568)
- Fix crash on recursive type aliases. Recursive type aliases now fall back to `Any` (#565)
- Support `in` on objects with only `__getitem__` (#564)
- Add support for `except*` (PEP 654) (#562)
Expand Down
1 change: 1 addition & 0 deletions pyanalyze/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
used(extensions.is_positional)
used(extensions.is_of_type)
used(extensions.show_error)
used(extensions.has_extra_keys)
used(value.UNRESOLVED_VALUE) # keeping it around for now just in case
used(reexport)
used(patma)
Expand Down
7 changes: 6 additions & 1 deletion pyanalyze/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,11 +449,16 @@ def _type_from_runtime(
# inheritance, this makes it apparently impossible to figure out which
# keys are required at runtime.
total = getattr(val, "__total__", True)
if hasattr(val, "__extra_keys__"):
extra_keys = _type_from_runtime(val.__extra_keys__, ctx)
else:
extra_keys = None
return TypedDictValue(
{
key: _get_typeddict_value(value, ctx, key, required_keys, total)
for key, value in val.__annotations__.items()
}
},
extra_keys=extra_keys,
)
elif val is InitVar:
# On 3.6 and 3.7, InitVar[T] just returns InitVar at runtime, so we can't
Expand Down
9 changes: 9 additions & 0 deletions pyanalyze/arg_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,15 @@ def _uncached_get_argspec(
)
for key, (required, value) in td_type.items.items()
]
if td_type.extra_keys is not None:
annotation = GenericValue(
dict, [TypedValue(str), td_type.extra_keys]
)
params.append(
SigParameter(
"%kwargs", ParameterKind.VAR_KEYWORD, annotation=annotation
)
)
return Signature.make(params, td_type)

if is_newtype(obj):
Expand Down
25 changes: 25 additions & 0 deletions pyanalyze/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,3 +551,28 @@ def show_error(message: str, *, argument: Optional[Any] = None) -> bool:
raise NotImplementedError(
"show_error() may only be called in type evaluation functions"
)


def has_extra_keys(value_type: object = Any) -> Callable[[_T], _T]:
"""Decorator for ``TypedDict`` types, indicating that the dict
has additional keys of the given type.
This is an experimental feature.
Example usage::
@has_extra_keys(str)
class TD(TypedDict):
a: int
def f(x: TD) -> None:
assert_type(x["a"], int)
assert_type(x["arbitrary_key"], str)
"""

def decorator(cls: _T) -> _T:
cls.__extra_keys__ = value_type
return cls

return decorator
41 changes: 27 additions & 14 deletions pyanalyze/implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,22 +479,27 @@ def _typeddict_setitem(
ErrorCode.invalid_typeddict_key,
arg="k",
)
elif key.val not in self_value.items:
ctx.show_error(
f"Key {key.val!r} does not exist in {self_value}",
ErrorCode.invalid_typeddict_key,
arg="k",
)
else:
_, expected_type = self_value.items[key.val]
tv_map = expected_type.can_assign(value, ctx.visitor)
if isinstance(tv_map, CanAssignError):
return
if key.val not in self_value.items:
if self_value.extra_keys is None:
ctx.show_error(
f"Value for key {key.val!r} must be {expected_type}, not {value}",
ErrorCode.incompatible_argument,
arg="v",
detail=str(tv_map),
f"Key {key.val!r} does not exist in {self_value}",
ErrorCode.invalid_typeddict_key,
arg="k",
)
return
else:
expected_type = self_value.extra_keys
else:
_, expected_type = self_value.items[key.val]
tv_map = expected_type.can_assign(value, ctx.visitor)
if isinstance(tv_map, CanAssignError):
ctx.show_error(
f"Value for key {key.val!r} must be {expected_type}, not {value}",
ErrorCode.incompatible_argument,
arg="v",
detail=str(tv_map),
)


def _check_dict_key_hashability(key: Value, ctx: CallContext, arg: str) -> bool:
Expand Down Expand Up @@ -554,6 +559,8 @@ def inner(key: Value) -> Value:
except Exception:
# No error here; TypedDicts may have additional keys at runtime.
pass
if self_value.extra_keys is not None:
return self_value.extra_keys
# TODO strictly we should throw an error for any non-Literal or unknown key:
# https://www.python.org/dev/peps/pep-0589/#supported-and-unsupported-operations
# Don't do that yet because it may cause too much disruption.
Expand Down Expand Up @@ -622,6 +629,8 @@ def inner(key: Value) -> Value:
return value
else:
return value | default
if self_value.extra_keys is not None:
return self_value.extra_keys | default
# TODO strictly we should throw an error for any non-Literal or unknown key:
# https://www.python.org/dev/peps/pep-0589/#supported-and-unsupported-operations
# Don't do that yet because it may cause too much disruption.
Expand Down Expand Up @@ -681,6 +690,8 @@ def _dict_pop_impl(ctx: CallContext) -> ImplReturn:
arg="key",
)
return ImplReturn(_maybe_unite(expected_type, default))
if self_value.extra_keys is not None:
return ImplReturn(_maybe_unite(self_value.extra_keys, default))
ctx.show_error(
f"Key {key} does not exist in TypedDict",
ErrorCode.invalid_typeddict_key,
Expand Down Expand Up @@ -767,6 +778,8 @@ def _dict_setdefault_impl(ctx: CallContext) -> ImplReturn:
arg="default",
)
return ImplReturn(expected_type)
if self_value.extra_keys is not None:
return ImplReturn(self_value.extra_keys | default)
ctx.show_error(
f"Key {key} does not exist in TypedDict",
ErrorCode.invalid_typeddict_key,
Expand Down
10 changes: 10 additions & 0 deletions pyanalyze/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -1720,6 +1720,16 @@ def make(
default=None if is_required else AnyValue(AnySource.marker),
)
i += 1
if param.annotation.extra_keys is not None:
name = f"%kwargs{i}"
param_dict[name] = SigParameter(
name,
ParameterKind.VAR_KEYWORD,
annotation=GenericValue(
dict, [TypedValue(str), param.annotation.extra_keys]
),
)
i += 1
else:
param_dict[param.name] = param
i += 1
Expand Down
118 changes: 118 additions & 0 deletions pyanalyze/test_typeddict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,124 @@
from .value import TypedDictValue, TypedValue


class TestExtraKeys(TestNameCheckVisitorBase):
@assert_passes()
def test_signature(self):
from pyanalyze.extensions import has_extra_keys
from typing_extensions import TypedDict

@has_extra_keys(int)
class TD(TypedDict):
a: str

def capybara() -> None:
x = TD(a="a", b=1)
assert_is_value(
x,
TypedDictValue(
{"a": (True, TypedValue(str))}, extra_keys=TypedValue(int)
),
)

TD(a="a", b="b") # E: incompatible_argument

@assert_passes()
def test_methods(self):
from pyanalyze.extensions import has_extra_keys
from typing_extensions import TypedDict, assert_type, Literal
from typing import Union

@has_extra_keys(int)
class TD(TypedDict):
a: str

class NormalTD(TypedDict):
a: str

def getitem(td: TD, ntd: NormalTD) -> None:
td["b"] = 3
ntd["b"] = 3 # E: invalid_typeddict_key

def setitem(td: TD) -> None:
assert_type(td["b"], int)

def get(td: TD) -> None:
assert_type(td.get("b", "x"), Union[int, Literal["x"]])

def pop(td: TD) -> None:
assert_type(td.pop("b"), int)

def setdefault(td: TD) -> None:
assert_type(td.setdefault("b", "x"), Union[int, Literal["x"]])

@assert_passes()
def test_kwargs_annotation(self):
from pyanalyze.extensions import has_extra_keys
from typing_extensions import TypedDict, Unpack, assert_type

@has_extra_keys(int)
class TD(TypedDict):
a: str

def caller(**kwargs: Unpack[TD]) -> None:
assert_type(kwargs["b"], int)

def capybara() -> None:
caller(a="x", b=1)
caller(a="x", b="y") # E: incompatible_argument

@assert_passes()
def test_compatibility(self):
from pyanalyze.extensions import has_extra_keys
from typing_extensions import TypedDict

@has_extra_keys(int)
class TD(TypedDict):
a: str

@has_extra_keys(bool)
class TD2(TypedDict):
a: str

@has_extra_keys(bytes)
class TD3(TypedDict):
a: str

def want_td(td: TD) -> None:
pass

def capybara(td: TD, td2: TD2, td3: TD3) -> None:
want_td(td)
want_td(td2)
want_td(td3) # E: incompatible_argument

@assert_passes()
def test_iteration(self):
from pyanalyze.extensions import has_extra_keys
from typing_extensions import TypedDict, assert_type, Literal
from typing import Union

@has_extra_keys(int)
class TD(TypedDict):
a: str

class TD2(TypedDict):
a: str

def capybara(td: TD, td2: TD2) -> None:
for k, v in td.items():
assert_type(k, Union[str, Literal["a"]])
assert_type(v, Union[int, str])
for k in td:
assert_type(k, Union[str, Literal["a"]])

for k, v in td2.items():
assert_type(k, Literal["a"])
assert_type(v, str)
for k in td2:
assert_type(k, Literal["a"])


class TestTypedDict(TestNameCheckVisitorBase):
@assert_passes()
def test_constructor(self):
Expand Down
59 changes: 50 additions & 9 deletions pyanalyze/value.py
Original file line number Diff line number Diff line change
Expand Up @@ -1147,14 +1147,29 @@ class TypedDictValue(GenericValue):
items: Dict[str, Tuple[bool, Value]]
"""The items of the ``TypedDict``. Required items are represented as (True, value) and optional
ones as (False, value)."""
extra_keys: Optional[Value] = None
"""The type of unknown keys, if any."""

def __init__(self, items: Dict[str, Tuple[bool, Value]]) -> None:
def __init__(
self, items: Dict[str, Tuple[bool, Value]], extra_keys: Optional[Value] = None
) -> None:
value_types = []
key_types = []
if items:
value_type = unite_values(*[val for _, val in items.values()])
else:
value_type = AnyValue(AnySource.unreachable)
super().__init__(dict, (TypedValue(str), value_type))
value_types += [val for _, val in items.values()]
key_types += [KnownValue(key) for key in items.keys()]
if extra_keys is not None:
value_types.append(extra_keys)
key_types.append(TypedValue(str))
value_type = (
unite_values(*value_types)
if value_types
else AnyValue(AnySource.unreachable)
)
key_type = unite_values(*key_types) if key_types else TypedValue(str)
super().__init__(dict, (key_type, value_type))
self.items = items
self.extra_keys = extra_keys

def num_required_keys(self) -> int:
return sum(1 for required, _ in self.items.values() if required)
Expand Down Expand Up @@ -1193,6 +1208,14 @@ def can_assign(self, other: Value, ctx: CanAssignContext) -> CanAssign:
children=[can_assign],
)
bounds_maps.append(can_assign)
# TODO: What if only one of the two has extra keys?
if self.extra_keys is not None and other.extra_keys is not None:
can_assign = self.extra_keys.can_assign(other.extra_keys, ctx)
if isinstance(can_assign, CanAssignError):
return CanAssignError(
"Types for extra keys are incompatible", children=[can_assign]
)
bounds_maps.append(can_assign)
return unify_bounds_maps(bounds_maps)
elif isinstance(other, KnownValue) and isinstance(other.val, dict):
bounds_maps = []
Expand All @@ -1216,7 +1239,10 @@ def substitute_typevars(self, typevars: TypeVarMap) -> "TypedDictValue":
{
key: (is_required, value.substitute_typevars(typevars))
for key, (is_required, value) in self.items.items()
}
},
extra_keys=self.extra_keys.substitute_typevars(typevars)
if self.extra_keys is not None
else None,
)

def __str__(self) -> str:
Expand Down Expand Up @@ -2288,9 +2314,14 @@ def concrete_values_from_iterable(
return value.args[0]
return members
elif isinstance(value, TypedDictValue):
if all(required for required, _ in value.items.items()):
if value.extra_keys is None and all(
required for required, _ in value.items.items()
):
return [KnownValue(key) for key in value.items]
return MultiValuedValue([KnownValue(key) for key in value.items])
possibilities = [KnownValue(key) for key in value.items]
if value.extra_keys is not None:
possibilities.append(TypedValue(str))
return MultiValuedValue(possibilities)
elif isinstance(value, DictIncompleteValue):
if all(pair.is_required and not pair.is_many for pair in value.kv_pairs):
return [pair.key for pair in value.kv_pairs]
Expand Down Expand Up @@ -2371,10 +2402,20 @@ def kv_pairs_from_mapping(
if isinstance(value_val, DictIncompleteValue):
return value_val.kv_pairs
elif isinstance(value_val, TypedDictValue):
return [
pairs = [
KVPair(KnownValue(key), value, is_required=required)
for key, (required, value) in value_val.items.items()
]
if value_val.extra_keys is not None:
pairs.append(
KVPair(
TypedValue(str),
value_val.extra_keys,
is_many=True,
is_required=False,
)
)
return pairs
else:
# Ideally we should only need to check ProtocolMappingValue, but if
# we do that we can't infer the right types for dict, so try the
Expand Down

0 comments on commit 2c93b3d

Please sign in to comment.