Inconsistent shape checking for lists of tensors #234

Open
sean-roelofs-ai opened this issue Jul 23, 2024 · 2 comments
Labels
question User queries

Comments

@sean-roelofs-ai

sean-roelofs-ai commented Jul 23, 2024

I am trying to type-check the shapes of Tensors inside a list. It seems the shape wildcards are not enforced consistently through the list. I wrote an example to describe the issue.

from beartype import beartype
from beartype.typing import List, Optional
from jaxtyping import Float, jaxtyped
import torch
from torch import Tensor
import unittest


@jaxtyped(typechecker=beartype)
def foo(x: Float[Tensor, "B C"], feat_list: List[Float[Tensor, "B C N"]], y: Optional[Float[Tensor, "N"]] = None):
    pass
    
class TestJaxtyping(unittest.TestCase):
    def test_variable_list(self):
        B = 8
        C = 32
        N_1 = 3
        N_2 = 4
        x = torch.zeros((B, C)) 
        feats_a = [torch.zeros((B, C, N_1)), torch.zeros((B, C, N_2))]
        feats_b = [torch.zeros((B, C, N_1)), torch.zeros((C, B, N_2))]
        y = torch.zeros(N_1)

        # I feel like this should error, but it doesn't
        foo(x, feats_a)
        
        try:
            foo(x, feats_b)
        except Exception:  # avoid a bare except, which would also swallow KeyboardInterrupt
            print("Jaxtyping caught the B C switch here")

        # most concerning of all, this works sometimes and not others depending on if N matches to N_1 or N_2
        foo(x, feats_a, y)


        
if __name__ == "__main__":
    unittest.main()

Is there any way to enforce tensor shapes inside a list correctly?

@patrick-kidger
Owner

I think this is a beartype thing: it only checks one element of a list.

CC @leycec

@leycec

leycec commented Jul 25, 2024

...heh. @beartype woes, huh? It's all true. @beartype guarantees constant-time O(1) complexity by only pseudo-randomly type-checking one item of each list. This is both a bad thing and a good thing. On the bright side, @beartype scales to arbitrarily large lists (and all other kinds of containers); @beartype is guaranteed to never Denial-of-Service (DoS) your workflow when a disturbingly large list (or other kind of container) inevitably gets passed in. On the dark side, non-deterministic type-checking kinda sucks. I get that and sympathize with your pain.
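To make the tradeoff concrete, here is a minimal pure-Python sketch of the two strategies. It is an illustration only, not @beartype's actual implementation: `check_one_random`, `check_all`, and `has_b_c_prefix` are hypothetical names, and plain shape tuples stand in for the tensors in `feats_b` from the original report.

```python
import random
from typing import Callable, Sequence


def check_one_random(
    items: Sequence, is_valid: Callable[[object], bool], rng: random.Random
) -> bool:
    """O(1) sketch: validate a single pseudo-randomly chosen item."""
    if not items:
        return True
    return is_valid(items[rng.randrange(len(items))])


def check_all(items: Sequence, is_valid: Callable[[object], bool]) -> bool:
    """O(n) sketch: validate every item."""
    return all(is_valid(item) for item in items)


def has_b_c_prefix(shape: tuple) -> bool:
    """Hypothetical validity test: leading dims must be (B, C) = (8, 32)."""
    return shape[:2] == (8, 32)


# Shape tuples standing in for feats_b: the second one swaps B and C.
shapes = [(8, 32, 3), (32, 8, 4)]

rng = random.Random(0)
# Repeating the sampled check shows that the verdict depends on which item is drawn.
outcomes = {check_one_random(shapes, has_b_c_prefix, rng) for _ in range(200)}
print(outcomes)

# The exhaustive check always sees the bad item.
print(check_all(shapes, has_b_c_prefix))
```

The sampled check costs the same for a list of two items or two million, but a single call can miss the bad item. The exhaustive check cannot miss it, but its cost grows with the list.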

You are now thinking: "I hate @beartype." You're not wrong, @sean-roelofs-ai. But... fear not! An upcoming release of @beartype will provide the sort of linear-time O(n) type-checking you want and need. If your use case can't wait until then, no judgement, bro. typeguard is a valid alternative to @beartype that might be a better fit here in the meantime: e.g.,

from typeguard import typechecked  # <-- this fills me with sadness
from beartype.typing import List, Optional
from jaxtyping import Float, jaxtyped
import torch
from torch import Tensor
import unittest

@jaxtyped(typechecker=typechecked)  # <-- sadness intensifies
def foo(x: Float[Tensor, "B C"], feat_list: List[Float[Tensor, "B C N"]], y: Optional[Float[Tensor, "N"]] = None):
    pass

Of course, typeguard comes with the opposite tradeoff. It type-checks everything and thus fails to scale to large problem domains. But... maybe that's not a problem here?
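If swapping the type checker is not an option either, one possible stopgap is an explicit linear-time check inside the function body. The sketch below is an assumption, not part of jaxtyping, @beartype, or typeguard: `common_shape` is a hypothetical helper that only relies on each item exposing a `.shape` attribute (as `torch.Tensor` does). `SimpleNamespace` objects stand in for tensors so the example runs without torch.

```python
from types import SimpleNamespace
from typing import Optional, Sequence, Tuple


def common_shape(tensors: Sequence, prefix: Sequence[int]) -> Optional[Tuple[int, ...]]:
    """Return the shape shared by every item, or raise TypeError.

    Every item must start with the leading dims in `prefix` (e.g. B and C)
    and must have exactly the same full shape as the first item.
    """
    prefix = tuple(prefix)
    shapes = [tuple(t.shape) for t in tensors]
    for index, shape in enumerate(shapes):
        if shape[: len(prefix)] != prefix:
            raise TypeError(
                f"item {index} has shape {shape}, expected leading dims {prefix}"
            )
        if shape != shapes[0]:
            raise TypeError(
                f"item {index} has shape {shape}, but item 0 has shape {shapes[0]}"
            )
    return shapes[0] if shapes else None


# Stand-ins for the lists in the original report (B=8, C=32, N_1=3, N_2=4).
consistent = [SimpleNamespace(shape=(8, 32, 3)), SimpleNamespace(shape=(8, 32, 3))]
feats_a = [SimpleNamespace(shape=(8, 32, 3)), SimpleNamespace(shape=(8, 32, 4))]

print(common_shape(consistent, (8, 32)))

try:
    common_shape(feats_a, (8, 32))
except TypeError as error:
    print(f"rejected: {error}")
```

Inside `foo`, a call such as `common_shape(feat_list, x.shape)` would run on every invocation regardless of which element the decorator happens to sample, at the cost of one pass over the list.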

And thanks so much to @patrick-kidger for pinging me on this. Hope you're having an amazing summer! It's been so long since we've GitHub chatted. How about that CrowdStrike fiasco, huh? Yikes. Oh, and this issue can (probably) be safely closed. As everyone has surmised, this is almost certainly @beartype's fault. jaxtyping is blameless in this and all things.

@patrick-kidger patrick-kidger added the question User queries label Aug 18, 2024