Skip to content

Commit

Permalink
Correctly check the self argument to @property getters (#506)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored Mar 22, 2022
1 parent d50a15b commit 345d374
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 71 deletions.
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Unreleased

- Correctly check the `self` argument to `@property` getters (#506)
- Correctly track assignments of variables inside `try` blocks
and inside `with` blocks that may suppress exceptions (#504)
- Support mappings that do not inherit from `collections.abc.Mapping`
Expand Down
11 changes: 0 additions & 11 deletions pyanalyze/arg_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
Impl,
MaybeSignature,
OverloadedSignature,
PropertyArgSpec,
make_bound_method,
SigParameter,
Signature,
Expand Down Expand Up @@ -792,16 +791,6 @@ def _uncached_get_argspec(
# these with inspect, so just give up.
return self._make_any_sig(obj)

if isinstance(obj, property):
# If we know the getter, inherit its return value.
if obj.fget:
fget_argspec = self._cached_get_argspec(
obj.fget, impl, is_asynq, in_overload_resolution
)
if fget_argspec is not None and fget_argspec.has_return_value():
return PropertyArgSpec(obj, return_value=fget_argspec.return_value)
return PropertyArgSpec(obj)

return None

def _make_any_sig(self, obj: object) -> Signature:
Expand Down
2 changes: 1 addition & 1 deletion pyanalyze/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def record_usage(self, obj: Any, val: Value) -> None:
def record_attr_read(self, obj: Any) -> None:
pass

def get_property_type_from_argspec(self, obj: Any) -> Value:
def get_property_type_from_argspec(self, obj: property) -> Value:
return AnyValue(AnySource.inference)

def get_attribute_from_typeshed(self, typ: type, *, on_class: bool) -> Value:
Expand Down
48 changes: 28 additions & 20 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@
is_union,
kv_pairs_from_mapping,
make_weak,
set_self,
unannotate_value,
unite_and_simplify,
unite_values,
Expand Down Expand Up @@ -324,6 +325,7 @@ class _AttrContext(CheckerAttrContext):
visitor: "NameCheckVisitor"
node: Optional[ast.AST]
ignore_none: bool = False
record_reads: bool = True

# Needs to be implemented explicitly to work around Cython limitations
def __init__(
Expand All @@ -332,11 +334,12 @@ def __init__(
attr: str,
visitor: "NameCheckVisitor",
*,
node: Optional[ast.AST] = None,
node: Optional[ast.AST],
ignore_none: bool = False,
skip_mro: bool = False,
skip_unwrap: bool = False,
prefer_typeshed: bool = False,
record_reads: bool = True,
) -> None:
super().__init__(
root_composite,
Expand All @@ -350,25 +353,21 @@ def __init__(
self.node = node
self.visitor = visitor
self.ignore_none = ignore_none
self.record_reads = record_reads

def record_usage(self, obj: object, val: Value) -> None:
self.visitor._maybe_record_usage(obj, self.attr, val)

def record_attr_read(self, obj: type) -> None:
if self.node is not None:
if self.record_reads and self.node is not None:
self.visitor._record_type_attr_read(obj, self.attr, self.node)

def get_property_type_from_argspec(self, obj: object) -> Value:
argspec = self.visitor.arg_spec_cache.get_argspec(obj)
if argspec is not None:
if argspec.has_return_value():
return argspec.return_value
# If we visited the property and inferred a return value,
# use it.
local = self.visitor.get_local_return_value(argspec)
if local is not None:
return local
return AnyValue(AnySource.inference)
def get_property_type_from_argspec(self, obj: property) -> Value:
if obj.fget is None:
return UNINITIALIZED_VALUE

getter = set_self(KnownValue(obj.fget), self.root_composite.value)
return self.visitor.check_call(self.node, getter, [self.root_composite])

def should_ignore_none_attributes(self) -> bool:
return self.ignore_none
Expand Down Expand Up @@ -1372,8 +1371,10 @@ def _check_for_incompatible_overrides(
Composite(base_class_value),
varname,
self,
node=node,
skip_mro=True,
skip_unwrap=True,
record_reads=False,
)
base_value = attributes.get_attribute(ctx)
can_assign = self._can_assign_to_base(base_value, value)
Expand Down Expand Up @@ -1723,6 +1724,12 @@ def _set_argspec_to_retval(
if isinstance(info.node, ast.AsyncFunctionDef) or info.is_decorated_coroutine:
return_value = GenericValue(collections.abc.Awaitable, [return_value])

if isinstance(val, KnownValue) and isinstance(val.val, property):
fget = val.val.fget
if fget is None:
return
val = KnownValue(fget)

sig = self.signature_from_value(val)
if sig is None or sig.has_return_value():
return
Expand Down Expand Up @@ -4407,7 +4414,7 @@ def _can_perform_call(

def check_call(
self,
node: ast.AST,
node: Optional[ast.AST],
callee: Value,
args: Iterable[Composite],
keywords: Iterable[Tuple[Optional[str], Composite]] = (),
Expand All @@ -4430,7 +4437,7 @@ def check_call(

def _check_call_no_mvv(
self,
node: ast.AST,
node: Optional[ast.AST],
callee_wrapped: Value,
args: Iterable[Composite],
keywords: Iterable[Tuple[Optional[str], Composite]] = (),
Expand All @@ -4451,11 +4458,12 @@ def _check_call_no_mvv(
return_value = AnyValue(AnySource.from_another)

elif extended_argspec is None:
self._show_error_if_checking(
node,
f"{callee_wrapped} is not callable",
error_code=ErrorCode.not_callable,
)
if node is not None:
self._show_error_if_checking(
node,
f"{callee_wrapped} is not callable",
error_code=ErrorCode.not_callable,
)
return_value = AnyValue(AnySource.error)

else:
Expand Down
71 changes: 34 additions & 37 deletions pyanalyze/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
Tuple,
TYPE_CHECKING,
)
from typing_extensions import Literal, Protocol, Self
from typing_extensions import Literal, Protocol, Self, assert_never

if TYPE_CHECKING:
from .name_check_visitor import NameCheckVisitor
Expand Down Expand Up @@ -183,7 +183,7 @@ def on_error(
@dataclass
class _VisitorBasedContext:
visitor: "NameCheckVisitor"
node: ast.AST
node: Optional[ast.AST]

@property
def can_assign_ctx(self) -> CanAssignContext:
Expand All @@ -197,7 +197,11 @@ def on_error(
node: Optional[ast.AST] = None,
detail: Optional[str] = ...,
) -> None:
self.visitor.show_error(node or self.node, message, code, detail=detail)
if node is None:
node = self.node
if node is None:
return
self.visitor.show_error(node, message, code, detail=detail)


@dataclass
Expand Down Expand Up @@ -278,7 +282,7 @@ class CallContext:
"""Using the visitor can allow various kinds of advanced logic
in impl functions."""
composites: Dict[str, Composite]
node: ast.AST
node: Optional[ast.AST]
"""AST node corresponding to the function call. Useful for
showing errors."""

Expand Down Expand Up @@ -1040,7 +1044,10 @@ def get_default_return(self, source: AnySource = AnySource.error) -> CallReturn:
return CallReturn(return_value, is_error=True)

def check_call(
self, args: Iterable[Argument], visitor: "NameCheckVisitor", node: ast.AST
self,
args: Iterable[Argument],
visitor: "NameCheckVisitor",
node: Optional[ast.AST],
) -> Value:
"""Type check a call to this Signature with the given arguments.
Expand Down Expand Up @@ -1586,9 +1593,15 @@ def substitute_typevars(self, typevars: TypeVarMap) -> "Signature":
params.append((name, param))
else:
params.append((name, param.substitute_typevars(typevars)))
params_dict = dict(params)
return_value = self.return_value.substitute_typevars(typevars)
# Returning the same object helps the local return value check, which relies
# on identity of signature objects.
if return_value == self.return_value and params_dict == self.parameters:
return self
return Signature(
dict(params),
self.return_value.substitute_typevars(typevars),
params_dict,
return_value,
impl=self.impl,
callable=self.callable,
is_asynq=self.is_asynq,
Expand Down Expand Up @@ -1998,7 +2011,10 @@ def __init__(self, sigs: Sequence[Signature]) -> None:
object.__setattr__(self, "signatures", tuple(sigs))

def check_call(
self, args: Iterable[Argument], visitor: "NameCheckVisitor", node: ast.AST
self,
args: Iterable[Argument],
visitor: "NameCheckVisitor",
node: Optional[ast.AST],
) -> Value:
"""Check a call to an overloaded function.
Expand Down Expand Up @@ -2179,9 +2195,10 @@ def _make_detail(
return CanAssignError(children=details)

def substitute_typevars(self, typevars: TypeVarMap) -> "OverloadedSignature":
return OverloadedSignature(
[sig.substitute_typevars(typevars) for sig in self.signatures]
)
new_sigs = [sig.substitute_typevars(typevars) for sig in self.signatures]
if all(sig1 is sig2 for sig1, sig2 in zip(self.signatures, new_sigs)):
return self
return OverloadedSignature(new_sigs)

def bind_self(
self,
Expand Down Expand Up @@ -2255,7 +2272,10 @@ class BoundMethodSignature:
return_override: Optional[Value] = None

def check_call(
self, args: Iterable[Argument], visitor: "NameCheckVisitor", node: ast.AST
self,
args: Iterable[Argument],
visitor: "NameCheckVisitor",
node: Optional[ast.AST],
) -> Value:
ret = self.signature.check_call(
[(self.self_composite, None), *args], visitor, node
Expand Down Expand Up @@ -2299,30 +2319,7 @@ def __str__(self) -> str:
return f"{self.signature} bound to {self.self_composite.value}"


@dataclass(frozen=True)
class PropertyArgSpec:
"""Pseudo-argspec for properties."""

obj: object
return_value: Value = AnyValue(AnySource.unannotated)

def check_call(
self, args: Iterable[Argument], visitor: "NameCheckVisitor", node: ast.AST
) -> Value:
raise TypeError("property object is not callable")

def has_return_value(self) -> bool:
return not isinstance(self.return_value, AnyValue)

def substitute_typevars(self, typevars: TypeVarMap) -> "PropertyArgSpec":
return PropertyArgSpec(
self.obj, self.return_value.substitute_typevars(typevars)
)


MaybeSignature = Union[
None, Signature, BoundMethodSignature, PropertyArgSpec, OverloadedSignature
]
MaybeSignature = Union[None, Signature, BoundMethodSignature, OverloadedSignature]


def make_bound_method(
Expand All @@ -2339,7 +2336,7 @@ def make_bound_method(
return_override = argspec.return_override
return BoundMethodSignature(argspec.signature, self_composite, return_override)
else:
assert False, f"invalid argspec {argspec}"
assert_never(argspec)


T = TypeVar("T")
Expand Down
41 changes: 41 additions & 0 deletions pyanalyze/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,16 @@ def test_property(self):
def capybara(uid):
assert_is_value(PropertyObject(uid).string_property, TypedValue(str))

@assert_passes()
def test_local_return(self):
class X:
@property
def foo(self):
return str(1)

def capybara() -> None:
assert_is_value(X().foo, TypedValue(str))


class TestShadowing(TestNameCheckVisitorBase):
@assert_passes()
Expand Down Expand Up @@ -1108,3 +1118,34 @@ def wrapper():
assert_type(func(1), int)
assert_type(func(1, 1), int)
assert_type(func("x"), float)


class TestSelfAnnotation(TestNameCheckVisitorBase):
@assert_passes()
def test_method(self):
from typing import Generic, TypeVar

T = TypeVar("T")

class Capybara(Generic[T]):
def method(self: "Capybara[int]") -> int:
return 1

def caller(ci: Capybara[int], cs: Capybara[str]):
assert_is_value(ci.method(), TypedValue(int))
cs.method() # E: incompatible_argument

@assert_passes()
def test_property(self):
from typing import Generic, TypeVar

T = TypeVar("T")

class Capybara(Generic[T]):
@property
def prop(self: "Capybara[int]") -> int:
return 1

def caller(ci: Capybara[int], cs: Capybara[str]):
assert_is_value(ci.prop, TypedValue(int))
cs.prop # E: incompatible_argument
2 changes: 0 additions & 2 deletions pyanalyze/value.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,8 +570,6 @@ def get_signature(
return None
if isinstance(signature, pyanalyze.signature.BoundMethodSignature):
signature = signature.get_signature(ctx=ctx)
if isinstance(signature, pyanalyze.signature.PropertyArgSpec):
return None
return signature

def substitute_typevars(self, typevars: TypeVarMap) -> "Value":
Expand Down

0 comments on commit 345d374

Please sign in to comment.