Stateful Equinox Module: how to annotate? #253

Open
EtaoinWu opened this issue Oct 3, 2024 · 2 comments
Labels
question User queries

Comments

EtaoinWu commented Oct 3, 2024

I recently came up with the following code:

```python
from typing import Self

import equinox as eqx
from beartype import beartype
from jax import numpy as jnp
from jaxtyping import Array, Float, jaxtyped


@jaxtyped(typechecker=beartype) # to typecheck __init__
@beartype
class Accumulator(eqx.Module):
    x: Float[Array, " n"]

    @jaxtyped
    def add(self, y: Float[Array, " n"]) -> Self:
        return self.__class__(self.x + y)
```

When running this code, jaxtyped emits a UserWarning saying that it prefers the @jaxtyped(typechecker=beartype) syntax. (This warning was added before beartype's __instancecheck_str__ pseudostandard was implemented.) However, using that syntax here makes beartype raise an error, because it lacks the context to figure out what typing.Self refers to. Therefore the code above is the only way to get it running.

However, this Accumulator has a problem. Suppose you write:

```python
@jaxtyped(typechecker=beartype)
def test_accumulator():
    x = jnp.ones(3)
    y = jnp.ones(4)
    acc1 = Accumulator(x)
    acc1 = acc1.add(x)
    acc2 = Accumulator(y)
    acc2 = acc2.add(y)
    return acc1, acc2
```

When calling acc2.add(y), it seems that n=3 is still in the memo from the earlier acc1.add(x) call, so a BeartypeCallHintParamViolation type-check error is raised.

So my question is: how does one properly type-annotate this kind of class?

@patrick-kidger
Owner

So the issue here is actually Accumulator.__init__. The lone @beartype adds type checks to the __init__ method, and these are what produce the n=3 binding.

Unfortunately, jaxtyping+beartype+Self just isn't really a supported combination right now.

patrick-kidger added the question (User queries) label on Oct 3, 2024

EtaoinWu commented Oct 3, 2024

> So the issue here is actually Accumulator.__init__.

Interestingly, it is not. If we comment out the line acc1 = acc1.add(x):

```python
# @jaxtyped(typechecker=beartype) # With or without this line
@beartype
class Accumulator(eqx.Module):
    ...

@jaxtyped(typechecker=beartype)
def test_accumulator():
    x = jnp.ones(3)
    y = jnp.ones(4)
    acc1 = Accumulator(x)
    # acc1 = acc1.add(x)
    acc2 = Accumulator(y)
    acc2 = acc2.add(y)
    return acc1, acc2

test_accumulator()
```

This actually doesn't raise any error.
