Specializing generic class over ParamSpec #15096

Closed
raphCode opened this issue Apr 21, 2023 · 2 comments
Labels
bug mypy got something wrong topic-paramspec PEP 612, ParamSpec, Concatenate

Comments

raphCode commented Apr 21, 2023

Bug Report

I tried to be too clever and wanted to feed the signature of an existing function into a generic class. The goal is to specialize the class for different signatures.*
I invested quite some time trying, but it always failed in funny ways.
I realize that this is probably not intended to work, but I could not find anything in the mypy / Python typing docs that explicitly forbids it. I am therefore not sure whether this is a mypy issue or a documentation issue.

I want to share my results here. Maybe they are useful as a test scenario or for discovering bugs.

To Reproduce

Common setup, shared with all examples:

from typing import *

P = ParamSpec("P")
R = TypeVar("R", covariant=True)

def forward(a: int, b: str) -> float:
    return 0.0

class Proto(Generic[P, R], Protocol):
    def forward(self, *args: P.args, **kwargs: P.kwargs) -> R:
        pass

    def default_implementation(self, *args: P.args, **kwargs: P.kwargs) -> R:
        return cast(R, ...)

Attempt 1: Custom generic class (playground)

class MetaProto(type):
    def __getitem__(cls, item: Callable[P, R]) -> Type[Proto[P, R]]:
        return Proto[P, R]

class ProtoFactory(metaclass=MetaProto):
    pass

reveal_type(ProtoFactory[forward])  # works

class A(ProtoFactory[forward]):  # fails: error
    pass

mypy output:
note: Revealed type is "Type[Proto[[a: builtins.int, b: builtins.str], builtins.float]]"
error: "ProtoFactory" expects no type arguments, but 1 given  [type-arg]
error: Function "forward" is not valid as a type  [valid-type]

Attempt 2: Class decorator (playground)

C = TypeVar("C")

class MakeProto(Generic[P, R]):
    def __init__(self, a: Callable[P, R]):
        pass

    def __call__(self, wrapped: Type[C]) -> Type[Proto[P, R]]:
        return Proto[P, R]

class A:
    pass

reveal_type(MakeProto(forward)(A))  # works

@MakeProto(forward)  # fails: decorator changes nothing
class B:
    pass

reveal_type(B)

mypy output:
note: Revealed type is "Type[Proto[[a: builtins.int, b: builtins.str], builtins.float]]"
note: Revealed type is "def () -> B"

Expected Behavior

No errors, inheritance from the specialized Protocol.
See also the lines with # works comments.

Additional comments

*I am aware of some alternatives:

  • per-method signature copy decorator
    My current workaround, using a hand-written class specialized for every signature I want to use. I write one method and copy the signature over to the others.
    The problem is, I have to repeat this for every method in every class instead of just specializing the generic class multiple times.
  • Specializing a class generic over TypeVarTuple / ParamSpec with a tuple of argument types:
    This unfortunately does not capture positional argument names.
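
The first alternative above, a per-method signature-copy decorator, might look roughly like the following minimal sketch. The names `copy_signature` and `Network` are hypothetical and not taken from the original code; the decorator does nothing at runtime and only tells the type checker to treat the decorated method as if it had the signature of the source function.

```python
from typing import Any, Callable, TypeVar, cast

F = TypeVar("F", bound=Callable[..., Any])


def copy_signature(source: F) -> Callable[[Callable[..., Any]], F]:
    """Hypothetical helper: present the decorated callable with the type of `source`."""

    def decorator(target: Callable[..., Any]) -> F:
        # No runtime wrapping: the same function object is returned,
        # only its static type is replaced by that of `source`.
        return cast(F, target)

    return decorator


class Network:
    def forward(self, a: int, b: str) -> float:
        return float(a + len(b))

    # Inside the class body `forward` is still a plain function,
    # so its type (including `self`) can be copied onto __call__.
    @copy_signature(forward)
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward(*args, **kwargs)


result = Network()(1, "ab")
```

The drawback described above is visible here: the decorator has to be repeated on every method of every class that should share the signature.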

My use case is related to PyTorch: Modules implement a forward() method, but are called through a superclass's __call__(), which knows nothing about forward()'s signature. I want to write an abstract base class that:

  • enforces a signature for implementors on forward()
  • provides the same signature in __call__() to typecheck the usage outside of the module
  • provides a new method which also shares forward()'s signature

Your Environment

  • Mypy version used: 1.2.0
  • Python version used: 3.11
@raphCode raphCode added the bug mypy got something wrong label Apr 21, 2023
@JelleZijlstra JelleZijlstra added the topic-paramspec PEP 612, ParamSpec, Concatenate label Apr 22, 2023
raphCode (Author) commented

I realized that, instead of trying to pass an existing function, I could also pass a Protocol which implements __call__().
mypy seems to accept that better, since the error 'Function "forward" is not valid as a type [valid-type]' disappears. I am not sure whether passing functions directly is also supposed to work.

However, there are still two bugs in mypy which block me from achieving my goal or testing any further:

@raphCode raphCode changed the title from "Specializing generic Protocol over ParamSpec with existing function signature" to "Specializing generic class over ParamSpec" Apr 27, 2023
erictraut commented

Functions cannot be used as type arguments to a generic class in the current Python static type system, so "Attempt 1" will not work.

"Attempt 2" doesn't look like it will work at runtime, at least not as it's currently implemented, because MakeProto.__call__ is returning a protocol class, which can't be instantiated at runtime. Otherwise it looks like this should work. The problem is that mypy doesn't currently honor class decorators; it simply ignores them. Unless/until that bug is fixed, this approach won't work in mypy. It works in pyright, which does honor class decorators.

I don't see anything else actionable in the bug report, so I think it can be closed.

@hauntsaninja hauntsaninja closed this as not planned Aug 24, 2023
raphCode added a commit to raphCode/master-thesis-muzero that referenced this issue Apr 10, 2024
This change provides the correct function signatures for __call__() and
si() in the network classes. This enables typechecking the uses /
calling locations of the networks in the rest of the code.

It immediately uncovered a bug in PerfectInformationRLPlayer.own_move(),
where the representation network was called with an additional None
argument before the observations. This is a remnant from the time with
beliefs in the representation network. It did not do any harm because
None inputs are filtered out in the fully connected networks.

Previously:
- forward() had the correct signature, but that is never called in the
  code, only overridden by network implementations
- __call__()'s signature came from nn.Module, where it has Any
  for arguments and return value
- si() came from NetworkBase, where I could only customize the return
  type

An ideal solution would be to have a base class which is generic over
function signatures, so the code can be reused.
I was not able to get that to work in mypy though:
python/mypy#15096

Therefore, there is a lot of duplicated code in each of the abstract base
classes.