Skip to content

Commit

Permalink
dict.get impl (#460)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored Feb 5, 2022
1 parent 39fd292 commit 11c4d2f
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Unreleased

- Add plugin providing a precise type for `dict.get` calls (#460)
- Fix internal error when an `__eq__` method throws (#461)
- Fix handling of `async def` methods in stubs (#459)
- Treat Thrift enums as compatible with protocols that
Expand Down
84 changes: 84 additions & 0 deletions pyanalyze/implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,81 @@ def inner(key: Value) -> Value:
return flatten_unions(inner, ctx.vars["k"])


def _dict_get_impl(ctx: CallContext) -> ImplReturn:
default = ctx.vars["default"]

def inner(key: Value) -> Value:
self_value = ctx.vars["self"]
if isinstance(self_value, AnnotatedValue):
self_value = self_value.value
if isinstance(key, KnownValue):
try:
hash(key.val)
except Exception:
ctx.show_error(
f"Dictionary key {key} is not hashable",
ErrorCode.unhashable_key,
arg="k",
)
return AnyValue(AnySource.error)
if isinstance(self_value, KnownValue):
if isinstance(key, KnownValue):
try:
return_value = self_value.val[key.val]
except Exception:
return default
else:
return KnownValue(return_value) | default
# else just treat it together with DictIncompleteValue
self_value = replace_known_sequence_value(self_value)
if isinstance(self_value, TypedDictValue):
if not TypedValue(str).is_assignable(key, ctx.visitor):
ctx.show_error(
f"TypedDict key must be str, not {key}",
ErrorCode.invalid_typeddict_key,
arg="k",
)
return AnyValue(AnySource.error)
elif isinstance(key, KnownValue):
try:
required, value = self_value.items[key.val]
# probably KeyError, but catch anything in case it's an
# unhashable str subclass or something
except Exception:
# No error here; TypedDicts may have additional keys at runtime.
pass
else:
if required:
return value
else:
return value | default
# TODO strictly we should throw an error for any non-Literal or unknown key:
# https://www.python.org/dev/peps/pep-0589/#supported-and-unsupported-operations
# Don't do that yet because it may cause too much disruption.
return default
elif isinstance(self_value, DictIncompleteValue):
val = self_value.get_value(key, ctx.visitor)
if val is UNINITIALIZED_VALUE:
return default
return val | default
elif isinstance(self_value, TypedValue):
key_type = self_value.get_generic_arg_for_type(dict, ctx.visitor, 0)
can_assign = key_type.can_assign(key, ctx.visitor)
if isinstance(can_assign, CanAssignError):
ctx.show_error(
f"Dictionary does not accept keys of type {key}",
error_code=ErrorCode.incompatible_argument,
detail=str(can_assign),
arg="key",
)
value_type = self_value.get_generic_arg_for_type(dict, ctx.visitor, 1)
return value_type | default
else:
return AnyValue(AnySource.inference)

return flatten_unions(inner, ctx.vars["key"])


def _dict_setdefault_impl(ctx: CallContext) -> ImplReturn:
key = ctx.vars["key"]
default = ctx.vars["default"]
Expand Down Expand Up @@ -1370,6 +1445,15 @@ def get_default_argspecs() -> Dict[object, Signature]:
callable=dict.__getitem__,
impl=_dict_getitem_impl,
),
Signature.make(
[
SigParameter("self", _POS_ONLY, annotation=TypedValue(dict)),
SigParameter("key", _POS_ONLY),
SigParameter("default", _POS_ONLY, default=KnownValue(None)),
],
callable=dict.get,
impl=_dict_get_impl,
),
Signature.make(
[
SigParameter("self", _POS_ONLY, annotation=TypedValue(dict)),
Expand Down
31 changes: 31 additions & 0 deletions pyanalyze/test_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,37 @@ def capybara(cond):
),
)

@assert_passes()
def test_dict_get(self):
from typing_extensions import TypedDict, NotRequired
from typing import Dict

class TD(TypedDict):
a: int
b: str
c: NotRequired[str]

def capybara(td: TD, s: str, d: Dict[str, int]):
assert_is_value(td.get("a"), TypedValue(int))
assert_is_value(td.get("c"), TypedValue(str) | KnownValue(None))
assert_is_value(td.get("c", 1), TypedValue(str) | KnownValue(1))
td.get(1) # E: invalid_typeddict_key

known = {"a": "b"}
assert_is_value(known.get("a"), KnownValue("b") | KnownValue(None))
assert_is_value(known.get("b", 1), KnownValue(1))
assert_is_value(known.get(s), KnownValue("b") | KnownValue(None))

incomplete = {**td, "b": 1, "d": s}
assert_is_value(incomplete.get("a"), TypedValue(int) | KnownValue(None))
assert_is_value(incomplete.get("b"), KnownValue(1) | KnownValue(None))
assert_is_value(incomplete.get("d"), TypedValue(str) | KnownValue(None))
assert_is_value(incomplete.get("e"), KnownValue(None))

assert_is_value(d.get("x"), TypedValue(int) | KnownValue(None))
assert_is_value(d.get(s), TypedValue(int) | KnownValue(None))
d.get(1) # E: incompatible_argument

@assert_passes()
def test_setdefault(self):
from typing_extensions import TypedDict
Expand Down

0 comments on commit 11c4d2f

Please sign in to comment.