Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correctly check the self argument to @property getters #506

Merged
merged 3 commits into from
Mar 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Unreleased

- Correctly check the `self` argument to `@property` getters (#506)
- Correctly track assignments of variables inside `try` blocks
and inside `with` blocks that may suppress exceptions (#504)
- Support mappings that do not inherit from `collections.abc.Mapping`
Expand Down
11 changes: 0 additions & 11 deletions pyanalyze/arg_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
Impl,
MaybeSignature,
OverloadedSignature,
PropertyArgSpec,
make_bound_method,
SigParameter,
Signature,
Expand Down Expand Up @@ -792,16 +791,6 @@ def _uncached_get_argspec(
# these with inspect, so just give up.
return self._make_any_sig(obj)

if isinstance(obj, property):
# If we know the getter, inherit its return value.
if obj.fget:
fget_argspec = self._cached_get_argspec(
obj.fget, impl, is_asynq, in_overload_resolution
)
if fget_argspec is not None and fget_argspec.has_return_value():
return PropertyArgSpec(obj, return_value=fget_argspec.return_value)
return PropertyArgSpec(obj)

return None

def _make_any_sig(self, obj: object) -> Signature:
Expand Down
2 changes: 1 addition & 1 deletion pyanalyze/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def record_usage(self, obj: Any, val: Value) -> None:
def record_attr_read(self, obj: Any) -> None:
pass

def get_property_type_from_argspec(self, obj: Any) -> Value:
def get_property_type_from_argspec(self, obj: property) -> Value:
return AnyValue(AnySource.inference)

def get_attribute_from_typeshed(self, typ: type, *, on_class: bool) -> Value:
Expand Down
48 changes: 28 additions & 20 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@
is_union,
kv_pairs_from_mapping,
make_weak,
set_self,
unannotate_value,
unite_and_simplify,
unite_values,
Expand Down Expand Up @@ -324,6 +325,7 @@ class _AttrContext(CheckerAttrContext):
visitor: "NameCheckVisitor"
node: Optional[ast.AST]
ignore_none: bool = False
record_reads: bool = True

# Needs to be implemented explicitly to work around Cython limitations
def __init__(
Expand All @@ -332,11 +334,12 @@ def __init__(
attr: str,
visitor: "NameCheckVisitor",
*,
node: Optional[ast.AST] = None,
node: Optional[ast.AST],
ignore_none: bool = False,
skip_mro: bool = False,
skip_unwrap: bool = False,
prefer_typeshed: bool = False,
record_reads: bool = True,
) -> None:
super().__init__(
root_composite,
Expand All @@ -350,25 +353,21 @@ def __init__(
self.node = node
self.visitor = visitor
self.ignore_none = ignore_none
self.record_reads = record_reads

def record_usage(self, obj: object, val: Value) -> None:
self.visitor._maybe_record_usage(obj, self.attr, val)

def record_attr_read(self, obj: type) -> None:
if self.node is not None:
if self.record_reads and self.node is not None:
self.visitor._record_type_attr_read(obj, self.attr, self.node)

def get_property_type_from_argspec(self, obj: object) -> Value:
argspec = self.visitor.arg_spec_cache.get_argspec(obj)
if argspec is not None:
if argspec.has_return_value():
return argspec.return_value
# If we visited the property and inferred a return value,
# use it.
local = self.visitor.get_local_return_value(argspec)
if local is not None:
return local
return AnyValue(AnySource.inference)
def get_property_type_from_argspec(self, obj: property) -> Value:
if obj.fget is None:
return UNINITIALIZED_VALUE

getter = set_self(KnownValue(obj.fget), self.root_composite.value)
return self.visitor.check_call(self.node, getter, [self.root_composite])

def should_ignore_none_attributes(self) -> bool:
return self.ignore_none
Expand Down Expand Up @@ -1372,8 +1371,10 @@ def _check_for_incompatible_overrides(
Composite(base_class_value),
varname,
self,
node=node,
skip_mro=True,
skip_unwrap=True,
record_reads=False,
)
base_value = attributes.get_attribute(ctx)
can_assign = self._can_assign_to_base(base_value, value)
Expand Down Expand Up @@ -1723,6 +1724,12 @@ def _set_argspec_to_retval(
if isinstance(info.node, ast.AsyncFunctionDef) or info.is_decorated_coroutine:
return_value = GenericValue(collections.abc.Awaitable, [return_value])

if isinstance(val, KnownValue) and isinstance(val.val, property):
fget = val.val.fget
if fget is None:
return
val = KnownValue(fget)

sig = self.signature_from_value(val)
if sig is None or sig.has_return_value():
return
Expand Down Expand Up @@ -4407,7 +4414,7 @@ def _can_perform_call(

def check_call(
self,
node: ast.AST,
node: Optional[ast.AST],
callee: Value,
args: Iterable[Composite],
keywords: Iterable[Tuple[Optional[str], Composite]] = (),
Expand All @@ -4430,7 +4437,7 @@ def check_call(

def _check_call_no_mvv(
self,
node: ast.AST,
node: Optional[ast.AST],
callee_wrapped: Value,
args: Iterable[Composite],
keywords: Iterable[Tuple[Optional[str], Composite]] = (),
Expand All @@ -4451,11 +4458,12 @@ def _check_call_no_mvv(
return_value = AnyValue(AnySource.from_another)

elif extended_argspec is None:
self._show_error_if_checking(
node,
f"{callee_wrapped} is not callable",
error_code=ErrorCode.not_callable,
)
if node is not None:
self._show_error_if_checking(
node,
f"{callee_wrapped} is not callable",
error_code=ErrorCode.not_callable,
)
return_value = AnyValue(AnySource.error)

else:
Expand Down
71 changes: 34 additions & 37 deletions pyanalyze/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
Tuple,
TYPE_CHECKING,
)
from typing_extensions import Literal, Protocol, Self
from typing_extensions import Literal, Protocol, Self, assert_never

if TYPE_CHECKING:
from .name_check_visitor import NameCheckVisitor
Expand Down Expand Up @@ -183,7 +183,7 @@ def on_error(
@dataclass
class _VisitorBasedContext:
visitor: "NameCheckVisitor"
node: ast.AST
node: Optional[ast.AST]

@property
def can_assign_ctx(self) -> CanAssignContext:
Expand All @@ -197,7 +197,11 @@ def on_error(
node: Optional[ast.AST] = None,
detail: Optional[str] = ...,
) -> None:
self.visitor.show_error(node or self.node, message, code, detail=detail)
if node is None:
node = self.node
if node is None:
return
self.visitor.show_error(node, message, code, detail=detail)


@dataclass
Expand Down Expand Up @@ -278,7 +282,7 @@ class CallContext:
"""Using the visitor can allow various kinds of advanced logic
in impl functions."""
composites: Dict[str, Composite]
node: ast.AST
node: Optional[ast.AST]
"""AST node corresponding to the function call. Useful for
showing errors."""

Expand Down Expand Up @@ -1040,7 +1044,10 @@ def get_default_return(self, source: AnySource = AnySource.error) -> CallReturn:
return CallReturn(return_value, is_error=True)

def check_call(
self, args: Iterable[Argument], visitor: "NameCheckVisitor", node: ast.AST
self,
args: Iterable[Argument],
visitor: "NameCheckVisitor",
node: Optional[ast.AST],
) -> Value:
"""Type check a call to this Signature with the given arguments.

Expand Down Expand Up @@ -1586,9 +1593,15 @@ def substitute_typevars(self, typevars: TypeVarMap) -> "Signature":
params.append((name, param))
else:
params.append((name, param.substitute_typevars(typevars)))
params_dict = dict(params)
return_value = self.return_value.substitute_typevars(typevars)
# Returning the same object helps the local return value check, which relies
# on identity of signature objects.
if return_value == self.return_value and params_dict == self.parameters:
return self
return Signature(
dict(params),
self.return_value.substitute_typevars(typevars),
params_dict,
return_value,
impl=self.impl,
callable=self.callable,
is_asynq=self.is_asynq,
Expand Down Expand Up @@ -1998,7 +2011,10 @@ def __init__(self, sigs: Sequence[Signature]) -> None:
object.__setattr__(self, "signatures", tuple(sigs))

def check_call(
self, args: Iterable[Argument], visitor: "NameCheckVisitor", node: ast.AST
self,
args: Iterable[Argument],
visitor: "NameCheckVisitor",
node: Optional[ast.AST],
) -> Value:
"""Check a call to an overloaded function.

Expand Down Expand Up @@ -2179,9 +2195,10 @@ def _make_detail(
return CanAssignError(children=details)

def substitute_typevars(self, typevars: TypeVarMap) -> "OverloadedSignature":
return OverloadedSignature(
[sig.substitute_typevars(typevars) for sig in self.signatures]
)
new_sigs = [sig.substitute_typevars(typevars) for sig in self.signatures]
if all(sig1 is sig2 for sig1, sig2 in zip(self.signatures, new_sigs)):
return self
return OverloadedSignature(new_sigs)

def bind_self(
self,
Expand Down Expand Up @@ -2255,7 +2272,10 @@ class BoundMethodSignature:
return_override: Optional[Value] = None

def check_call(
self, args: Iterable[Argument], visitor: "NameCheckVisitor", node: ast.AST
self,
args: Iterable[Argument],
visitor: "NameCheckVisitor",
node: Optional[ast.AST],
) -> Value:
ret = self.signature.check_call(
[(self.self_composite, None), *args], visitor, node
Expand Down Expand Up @@ -2299,30 +2319,7 @@ def __str__(self) -> str:
return f"{self.signature} bound to {self.self_composite.value}"


@dataclass(frozen=True)
class PropertyArgSpec:
"""Pseudo-argspec for properties."""

obj: object
return_value: Value = AnyValue(AnySource.unannotated)

def check_call(
self, args: Iterable[Argument], visitor: "NameCheckVisitor", node: ast.AST
) -> Value:
raise TypeError("property object is not callable")

def has_return_value(self) -> bool:
return not isinstance(self.return_value, AnyValue)

def substitute_typevars(self, typevars: TypeVarMap) -> "PropertyArgSpec":
return PropertyArgSpec(
self.obj, self.return_value.substitute_typevars(typevars)
)


MaybeSignature = Union[
None, Signature, BoundMethodSignature, PropertyArgSpec, OverloadedSignature
]
MaybeSignature = Union[None, Signature, BoundMethodSignature, OverloadedSignature]


def make_bound_method(
Expand All @@ -2339,7 +2336,7 @@ def make_bound_method(
return_override = argspec.return_override
return BoundMethodSignature(argspec.signature, self_composite, return_override)
else:
assert False, f"invalid argspec {argspec}"
assert_never(argspec)


T = TypeVar("T")
Expand Down
41 changes: 41 additions & 0 deletions pyanalyze/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,16 @@ def test_property(self):
def capybara(uid):
assert_is_value(PropertyObject(uid).string_property, TypedValue(str))

@assert_passes()
def test_local_return(self):
class X:
@property
def foo(self):
return str(1)

def capybara() -> None:
assert_is_value(X().foo, TypedValue(str))


class TestShadowing(TestNameCheckVisitorBase):
@assert_passes()
Expand Down Expand Up @@ -1108,3 +1118,34 @@ def wrapper():
assert_type(func(1), int)
assert_type(func(1, 1), int)
assert_type(func("x"), float)


class TestSelfAnnotation(TestNameCheckVisitorBase):
@assert_passes()
def test_method(self):
from typing import Generic, TypeVar

T = TypeVar("T")

class Capybara(Generic[T]):
def method(self: "Capybara[int]") -> int:
return 1

def caller(ci: Capybara[int], cs: Capybara[str]):
assert_is_value(ci.method(), TypedValue(int))
cs.method() # E: incompatible_argument

@assert_passes()
def test_property(self):
from typing import Generic, TypeVar

T = TypeVar("T")

class Capybara(Generic[T]):
@property
def prop(self: "Capybara[int]") -> int:
return 1

def caller(ci: Capybara[int], cs: Capybara[str]):
assert_is_value(ci.prop, TypedValue(int))
cs.prop # E: incompatible_argument
2 changes: 0 additions & 2 deletions pyanalyze/value.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,8 +570,6 @@ def get_signature(
return None
if isinstance(signature, pyanalyze.signature.BoundMethodSignature):
signature = signature.get_signature(ctx=ctx)
if isinstance(signature, pyanalyze.signature.PropertyArgSpec):
return None
return signature

def substitute_typevars(self, typevars: TypeVarMap) -> "Value":
Expand Down