Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NoReturnGuard #370

Merged
merged 11 commits into from
Dec 26, 2021
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Unreleased

- Add `pyanalyze.extensions.NoReturnGuard` (#370)
- Infer call signatures for `Type[X]` (#369)
- Support configuration in a `pyproject.toml` file (#368)
- Require `typeshed_client` 2.0 (#361)
Expand Down
6 changes: 6 additions & 0 deletions pyanalyze/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
CustomCheck,
ExternalType,
HasAttrGuard,
NoReturnGuard,
ParameterTypeGuard,
TypeGuard,
)
Expand All @@ -71,6 +72,7 @@
KnownValue,
MultiValuedValue,
NO_RETURN_VALUE,
NoReturnGuardExtension,
ParameterTypeGuardExtension,
TypeGuardExtension,
TypedValue,
Expand Down Expand Up @@ -1068,6 +1070,10 @@ def _value_from_metadata(entry: Value, ctx: Context) -> Union[Value, Extension]:
return ParameterTypeGuardExtension(
entry.val.varname, _type_from_runtime(entry.val.guarded_type, ctx)
)
elif isinstance(entry.val, NoReturnGuard):
return NoReturnGuardExtension(
entry.val.varname, _type_from_runtime(entry.val.guarded_type, ctx)
)
elif isinstance(entry.val, HasAttrGuard):
return HasAttrGuardExtension(
entry.val.varname,
Expand Down
29 changes: 23 additions & 6 deletions pyanalyze/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,20 +205,20 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
raise TypeError(f"{self} is not callable")


class _ParameterTypeGuardMeta(type):
def __getitem__(self, params: Tuple[str, object]) -> "ParameterTypeGuard":
class _ParameterGuardMeta(type):
def __getitem__(self, params: Tuple[str, object]) -> Any:
if not isinstance(params, tuple) or len(params) != 2:
raise TypeError(
"ParameterTypeGuard[...] should be instantiated "
f"{self.__name__}[...] should be instantiated "
"with two arguments (a variable name and a type)."
)
if not isinstance(params[0], str):
raise TypeError("The first argument to ParameterTypeGuard must be a string")
return ParameterTypeGuard(params[0], params[1])
raise TypeError(f"The first argument to {self.__name__} must be a string")
return self(params[0], params[1])


@dataclass(frozen=True)
class ParameterTypeGuard(metaclass=_ParameterTypeGuardMeta):
class ParameterTypeGuard(metaclass=_ParameterGuardMeta):
"""A guard on an arbitrary parameter. Used with ``Annotated``.

Example usage::
Expand All @@ -232,6 +232,23 @@ def is_int(arg: object) -> Annotated[bool, ParameterTypeGuard["arg", int]]:
guarded_type: object


@dataclass(frozen=True)
class NoReturnGuard(metaclass=_ParameterGuardMeta):
"""A no-return guard on an arbitrary parameter. Used with ``Annotated``.

If the function returns, then the condition is true.

Example usage::

def assert_is_int(arg: object) -> Annotated[bool, NoReturnGuard["arg", int]]:
assert isinstance(arg, int)

"""

varname: str
guarded_type: object


class _HasAttrGuardMeta(type):
def __getitem__(self, params: Tuple[str, str, object]) -> "HasAttrGuard":
if not isinstance(params, tuple) or len(params) != 3:
Expand Down
2 changes: 1 addition & 1 deletion pyanalyze/implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,7 +894,7 @@ def _set_add_impl(ctx: CallContext) -> ImplReturn:

def _remove_annotated(val: Value) -> Value:
if isinstance(val, AnnotatedValue):
return val.value
return _remove_annotated(val.value)
elif isinstance(val, MultiValuedValue):
return unite_values(*[_remove_annotated(subval) for subval in val.vals])
return val
Expand Down
38 changes: 18 additions & 20 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@
ANY_SIGNATURE,
BoundMethodSignature,
ConcreteSignature,
ImplReturn,
MaybeSignature,
OverloadedSignature,
Signature,
Expand Down Expand Up @@ -145,6 +144,7 @@
KnownValueWithTypeVars,
UNINITIALIZED_VALUE,
NO_RETURN_VALUE,
NoReturnConstraintExtension,
kv_pairs_from_mapping,
make_weak,
unannotate_value,
Expand Down Expand Up @@ -1557,6 +1557,10 @@ def visit_FunctionDef(
and not info.is_overload
and expected_return_value is not None
and expected_return_value != KnownNone
and not (
isinstance(expected_return_value, AnnotatedValue)
and expected_return_value.value == KnownNone
)
and not any(
decorator == KnownValue(abstractmethod)
for _, decorator in info.decorators
Expand Down Expand Up @@ -4373,15 +4377,15 @@ def _check_call_no_mvv(
if extended_argspec is ANY_SIGNATURE:
# don't bother calling it
extended_argspec = None
impl_ret = ImplReturn(AnyValue(AnySource.from_another))
return_value = AnyValue(AnySource.from_another)

elif extended_argspec is None:
self._show_error_if_checking(
node,
f"{callee_wrapped} is not callable",
error_code=ErrorCode.not_callable,
)
impl_ret = ImplReturn(AnyValue(AnySource.error))
return_value = AnyValue(AnySource.error)

else:
arguments = [
Expand All @@ -4394,16 +4398,16 @@ def _check_call_no_mvv(
for keyword, value in keywords
]
if self._is_checking():
impl_ret = extended_argspec.check_call(arguments, self, node)
return_value = extended_argspec.check_call(arguments, self, node)
else:
with self.catch_errors():
impl_ret = extended_argspec.check_call(arguments, self, node)

return_value = impl_ret.return_value
constraint = impl_ret.constraint
return_value = extended_argspec.check_call(arguments, self, node)

if impl_ret.no_return_unless is not NULL_CONSTRAINT:
self.add_constraint(node, impl_ret.no_return_unless)
return_value, nru_extensions = unannotate_value(
return_value, NoReturnConstraintExtension
)
for extension in nru_extensions:
self.add_constraint(node, extension.constraint)

if (
extended_argspec is not None
Expand Down Expand Up @@ -4441,22 +4445,16 @@ def _check_call_no_mvv(
callee_wrapped.val
):
async_fn = callee_wrapped.val.__self__
return AsyncTaskIncompleteValue(
_get_task_cls(async_fn),
annotate_with_constraint(return_value, constraint),
)
return AsyncTaskIncompleteValue(_get_task_cls(async_fn), return_value)
elif isinstance(
callee_wrapped, UnboundMethodValue
) and callee_wrapped.secondary_attr_name in ("async", "asynq"):
async_fn = callee_wrapped.get_method()
return AsyncTaskIncompleteValue(
_get_task_cls(async_fn),
annotate_with_constraint(return_value, constraint),
)
return AsyncTaskIncompleteValue(_get_task_cls(async_fn), return_value)
elif isinstance(callee_wrapped, UnboundMethodValue) and asynq.is_pure_async_fn(
callee_wrapped.get_method()
):
return annotate_with_constraint(return_value, constraint)
return return_value
else:
if (
isinstance(return_value, AnyValue)
Expand All @@ -4466,7 +4464,7 @@ def _check_call_no_mvv(
task_cls = _get_task_cls(callee_wrapped.val)
if isinstance(task_cls, type):
return TypedValue(task_cls)
return annotate_with_constraint(return_value, constraint)
return return_value

def signature_from_value(
self, value: Value, node: Optional[ast.AST] = None
Expand Down
Loading