How to type a "DTypeLike" argument and runtime-check it? #165

Open
pchanial opened this issue Feb 5, 2024 · 1 comment
Labels
feature New feature


pchanial commented Feb 5, 2024

When I run the following code through beartype:

from jaxtyping import Array, DTypeLike, Shaped
import jax.numpy as jnp

class Foo:
    n = 10

    def ones(self, dtype: DTypeLike | None = None) -> Shaped[Array, '...']:
        return jnp.ones(self.n, dtype=dtype)

I get the following error:

E   beartype.roar.BeartypeDecorHintPep3119Exception: Method ...check_return() parameter "dtype" type hint <class 'jax._src.typing.SupportsDType'> uncheckable at runtime (i.e., not passable as second parameter to isinstance(), due to raising "TypeError: Instance and class checks can only be used with @runtime_checkable protocols" from metaclass __instancecheck__() method).

I understand that jaxtyping re-exports DTypeLike from jax.typing, but is there a way to make the above code compatible with runtime type checkers when using jaxtyping?

@patrick-kidger
Owner

Thanks for the report!

Looks like the SupportsDType protocol underlying jax.typing.DTypeLike needs the typing.runtime_checkable decorator. I'd suggest opening an issue (or one-line PR) on the main JAX repo.
