Change the type hint for nn.Module.__call__ to be friendly to overrides. #74746
Comments
Actually, using |
There might be some difference in the type inference mechanism? For reference, I also discussed this issue in the pyright repo: microsoft/pyright#3249 (comment) |
For
```python
# main.py
import torch
from torch import nn


class LinWrapper(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(12, 2)

    def forward(
        self, x: torch.Tensor
    ) -> torch.Tensor:
        out = self.lin(x)
        # reveal_type is a mypy-only helper here; remove it before running
        # (or import it from typing on Python 3.11+).
        reveal_type(out)
        return out
```

Bash command to run: `mypy main.py`

Output:
Note: when tested with a custom module (with type hints), it reveals to be |
Can someone open a PR with the proposed change so we can see how it does on CI? |
One possible approach is to copy the type signature of the `forward` method if it has any type annotations:

```python
# Pseudocode: type checkers only accept static type expressions as
# annotations, so an annotation computed at runtime like this is not
# understood statically.
def get_type_signature(func):
    # if func has type annotations:
    #     return the annotation
    # otherwise:
    #     return Callable[..., Any]
    pass

forward_type_sign = get_type_signature(self.forward)
__call__ : forward_type_sign = _call_impl
```

But I am not sure how to do this in a reliable way. |
Any updates on this? If not, is there a simple workaround? |
I have tried many different things. Here is the cleanest workaround I have:

```python
from typing import Callable, ParamSpec, TypeVar  # ParamSpec: Python 3.10+

from torch import Tensor, nn

_P = ParamSpec("_P")
_R = TypeVar("_R")


def _returns_nn_module_call(*args):
    # At runtime, the decorated __call__ is replaced by nn.Module.__call__ itself.
    return nn.Module.__call__


def patch_call(src_func: Callable[_P, _R]) -> Callable[..., Callable[_P, _R]]:
    # Statically, the decorated function takes on src_func's signature.
    return _returns_nn_module_call


class MyModule(nn.Module):
    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        return x + y

    @patch_call(forward)
    def __call__(self): ...
```

Defining a function would override the inherited type annotation. Here, I use the decorator to replace the call signature with that of the `forward` method.

```python
def patch_call(src_func: Callable[_P, _R]):
    def decorator(target_fn) -> Callable[_P, _R]:
        def inner(self, *args, **kwargs):
            # Name the base class explicitly: super(self.__class__, self)
            # recurses forever if MyModule is subclassed again.
            return nn.Module.__call__(self, *args, **kwargs)
        return inner
    return decorator
```

with the cost of one additional stack depth at every call. |
🚀 The feature, motivation and pitch
Currently, `nn.Module.__call__` has type hints defined using `__call__ : Callable[..., Any] = _call_impl`. However, this declared type makes it difficult for the user to override the type hint of `nn.Module.__call__` using inferred types. For example:

If you paste the above code in VSCode, you can see that the type for `z1` is `Any` (wrong) while that for `z2` is `Tensor` (correct).

Alternatives
If we instead remove the type annotations in `nn.Module` and do the following, the above example works as expected.

Additional context
No response
cc @albanD @mruberry @jbschlosser @walterddr @kshitij12345 @ezyang @malfet @rgommers @xuzhao9 @gramster