Provide a way to get error diagnostics out of isinstance checks #167

Open
reinerp opened this issue Feb 8, 2024 · 1 comment
Labels
feature New feature

Comments


reinerp commented Feb 8, 2024

The `assert isinstance(...)` pattern prints a mostly useless message: just `AssertionError`, with no explanation. Would it be possible to expose an `assertIsInstance(x, ty)` API that prints expected versus actual, as we already get for errors in function arguments?
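For ordinary Python types, such an API might look roughly like the minimal sketch below. The name `assert_isinstance` and the message format are hypothetical (not an existing jaxtyping or beartype function), and it only reports the expected versus actual *type*, so it cannot explain shape or dtype mismatches the way jaxtyping's own checks can.

```python
def assert_isinstance(x, ty):
    """Hypothetical helper: like `assert isinstance(x, ty)`, but with a message.

    Raises AssertionError stating the expected type(s) and the actual type of `x`.
    """
    if not isinstance(x, ty):
        # isinstance also accepts a tuple of types, so format that case readably.
        if isinstance(ty, tuple):
            expected = " | ".join(getattr(t, "__qualname__", repr(t)) for t in ty)
        else:
            expected = getattr(ty, "__qualname__", repr(ty))
        actual = type(x).__qualname__
        raise AssertionError(f"expected {expected}, got {actual} ({x!r})")


# Usage
assert_isinstance(3, int)  # passes silently
try:
    assert_isinstance("3", (int, float))
except AssertionError as e:
    print(e)  # expected int | float, got str ('3')
```

Raising `AssertionError` explicitly (rather than using an `assert` statement inside the helper) keeps the check active even when Python runs with `-O`.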

patrick-kidger added the `feature` label Feb 8, 2024
patrick-kidger (Owner)

So this actually dovetails well with another feature I would like to add.

Beartype now supports checking for an `__instancecheck_str__` method (see the beartype release notes and the relevant jaxtyping discussion thread).

Once this is added, your use case could easily be supported via `assert isinstance(x, ty), ty.__instancecheck_str__(x)`.

This shouldn't be too much work to add. To discuss some jaxtyping internals briefly, the plan is basically to rewrite things from

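```python
class _MetaAbstractArray(type):
    def __instancecheck__(cls, obj):
        # Current approach: each failing condition just returns False,
        # so the caller learns nothing about *why* the check failed.
        if something_bad:  # placeholder for a shape/dtype condition
            return False
        ...
        return True
```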

to

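```python
class _MetaAbstractArray(type):
    def __instancecheck__(cls, obj):
        # An empty explanation string means the check passed.
        return cls.__instancecheck_str__(obj) == ""

    def __instancecheck_str__(cls, obj):
        # Each failing condition returns a description of the failure,
        # followed by the current values of the axis bindings.
        if something_bad:  # placeholder for a shape/dtype condition
            return "something bad!" + _exc_shape_info(get_shape_memo())
        ...
        return ""
```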

which would report both exactly how the check failed (more than we get at the moment under any circumstances!) and all the extra information about the current values of the bindings (via `_exc_shape_info`).
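The shape of this refactor can be tried out without jaxtyping. Below is a self-contained toy version: a hypothetical check that a value is a list or tuple of a given element type and length. The names `_MetaSized` and `sized` are made up for illustration, and the binding information from `_exc_shape_info`/`get_shape_memo` is omitted. As in the plan above, `__instancecheck_str__` returns an empty string on success and a description of the failure otherwise, and `__instancecheck__` is derived from it, so the two cannot disagree.

```python
class _MetaSized(type):
    """Toy metaclass: instances are lists/tuples of fixed element type and length."""

    def __instancecheck__(cls, obj):
        # An empty explanation means the check passed.
        return cls.__instancecheck_str__(obj) == ""

    def __instancecheck_str__(cls, obj):
        if not isinstance(obj, (list, tuple)):
            return f"this value is not a list or tuple, got {type(obj).__name__}"
        if len(obj) != cls.length:
            return f"this value has length {len(obj)}, but expected {cls.length}"
        for i, item in enumerate(obj):
            if not isinstance(item, cls.dtype):
                return (
                    f"element {i} has type {type(item).__name__}, "
                    f"but expected {cls.dtype.__name__}"
                )
        return ""


def sized(dtype, length):
    """Build a new checkable type, loosely analogous to `Float[Array, "3"]`."""
    name = f"Sized[{dtype.__name__}, {length}]"
    return _MetaSized(name, (), {"dtype": dtype, "length": length})


# Usage: the same assert pattern as above, now with a precise reason on failure.
Vec3 = sized(float, 3)
print(isinstance([1.0, 2.0, 3.0], Vec3))  # True
x = [1.0, 2.0]
try:
    assert isinstance(x, Vec3), Vec3.__instancecheck_str__(x)
except AssertionError as e:
    print(e)  # this value has length 2, but expected 3
```

Because `__instancecheck_str__` is defined on the metaclass, it can be called directly on the type (`Vec3.__instancecheck_str__(x)`), mirroring the `ty.__instancecheck_str__(x)` call above.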

I'd be happy to guide a pull request on this; else I'm hoping to get around to this myself in the near future.

patrick-kidger added a commit that referenced this issue Feb 12, 2024
Recall that jaxtyping currently generates rich error messages in precisely one scenario: for the arguments and return types of a function decorated with:
```python
@jaxtyped(typechecker=beartype)
def foo(...): ...
```

With this commit we add support for beartype 0.17.0's pseudo-standard `__instancecheck_str__`, which means the following:

1. For those using beartype decorators, the following will *also* generate an informative error message, and moreover it will state exactly why the check failed (shape mismatch, dtype mismatch, etc.):
    ```python
    @jaxtyped(typechecker=None)
    @beartype
    def foo(...): ...
    ```
    (In practice we probably won't recommend the above combination in the docs just to keep things simple.)

2. For those using the beartype import hook together with the jaxtyping import hook, we can probably also check `assert isinstance(x, Float[Array, "foo"])` statements with rich error messages (#153). We'll need to test and document that, though. (@jeezrick interested?)

3. For those using plain `assert isinstance(...)` statements without beartype (#167, tagging @reinerp), they can *also* get rich error messages by doing
    ```python
    tt = Float[Array, "foo"]
    assert isinstance(x, tt), tt.__instancecheck_str__(x) + "\n" + print_bindings()
    ```
    which is still a bit long-winded right now but is a step in the right direction.

(CC @leycec for interest.)
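On the consumer side, a checker that honours this pseudo-standard only needs to look the method up on the hint and fall back to a generic message when it is absent. The sketch below is not beartype's implementation: `check_instance`, its message format, and the toy `Positive` hint are all hypothetical, and it assumes `hint` is a single class rather than a tuple or a `typing` construct.

```python
def check_instance(obj, hint):
    """Hypothetical checker: raise TypeError with the richest message available."""
    if isinstance(obj, hint):
        return
    # Methods defined on a metaclass are visible as attributes of the class itself.
    explain = getattr(hint, "__instancecheck_str__", None)
    if explain is not None:
        reason = explain(obj)
    else:
        reason = f"expected {hint.__qualname__}, got {type(obj).__qualname__}"
    raise TypeError(f"{obj!r} failed type check: {reason}")


class _MetaPositive(type):
    """Toy hint metaclass: accepts positive ints and explains failures."""

    def __instancecheck__(cls, obj):
        return cls.__instancecheck_str__(obj) == ""

    def __instancecheck_str__(cls, obj):
        if not isinstance(obj, int):
            return f"not an int (got {type(obj).__name__})"
        if obj <= 0:
            return f"{obj} is not positive"
        return ""


class Positive(metaclass=_MetaPositive):
    pass


# Usage
check_instance(5, Positive)  # passes
for value, hint in [(-2, Positive), ("a", int)]:
    try:
        check_instance(value, hint)
    except TypeError as e:
        print(e)
# -2 failed type check: -2 is not positive
# 'a' failed type check: expected int, got str
```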
patrick-kidger added a commit that referenced this issue Feb 17, 2024
patrick-kidger added a commit that referenced this issue Feb 17, 2024
patrick-kidger added a commit that referenced this issue Feb 17, 2024
patrick-kidger added a commit that referenced this issue Feb 25, 2024