I am trying to type-check the shape of Tensors inside a list. It seems like the shape wildcards are not enforced consistently throughout the list. I wrote an example to help describe the issue.
```python
from beartype import beartype
from beartype.typing import List, Optional
from jaxtyping import Float, jaxtyped
import torch
from torch import Tensor
import unittest


@jaxtyped(typechecker=beartype)
def foo(x: Float[Tensor, "B C"], feat_list: List[Float[Tensor, "B C N"]], y: Optional[Float[Tensor, "N"]] = None):
    pass


class TestJaxtyping(unittest.TestCase):
    def test_variable_list(self):
        B = 8
        C = 32
        N_1 = 3
        N_2 = 4
        x = torch.zeros((B, C))
        feats_a = [torch.zeros((B, C, N_1)), torch.zeros((B, C, N_2))]
        feats_b = [torch.zeros((B, C, N_1)), torch.zeros((C, B, N_2))]
        y = torch.zeros(N_1)
        # I feel like this should error, but it doesn't
        foo(x, feats_a)
        try:
            foo(x, feats_b)
        except Exception:
            print("Jaxtyping caught the B C switch here")
        # most concerning of all, this works sometimes and not others depending on whether N matches N_1 or N_2
        foo(x, feats_a, y)


if __name__ == "__main__":
    unittest.main()
```
Is there any way to enforce tensor shapes inside a list correctly?
...heh. @beartype woes, huh? It's all true. @beartype guarantees constant-time O(1) complexity by only pseudo-randomly type-checking one item of each list. This is both a bad thing and a good thing. On the bright side, @beartype scales to arbitrarily large lists (and all other kinds of containers); @beartype is guaranteed to never Denial-of-Service (DoS) your workflow when a disturbingly large list (or other kind of container) inevitably gets passed in. On the dark side, non-deterministic type-checking kinda sucks. I get that and sympathize with your pain.
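To make the nondeterminism concrete, here is a minimal pure-Python sketch that replays the three calls from the issue using shape tuples instead of tensors. The helpers `bind_dims`, `check_sampled_item` and `call_foo` are hypothetical: they mimic the idea of jaxtyping binding each named dimension once per call and @beartype checking a single list item, but they are not either library's real implementation. The `index` argument pins which item gets "sampled" so both possible outcomes can be shown side by side.

```python
import random


def bind_dims(shape, spec, memo):
    """Hypothetical helper: check a shape tuple against a spec like "B C N".

    Each name is bound to a size the first time it is seen (stored in memo);
    later occurrences must match that size. Loosely mimics jaxtyping's idea
    of consistent dimension names within one call.
    """
    names = spec.split()
    if len(shape) != len(names):
        return False
    for name, size in zip(names, shape):
        # First occurrence binds the name; later occurrences must agree.
        if memo.setdefault(name, size) != size:
            return False
    return True


def check_sampled_item(items, spec, memo, index=None):
    """O(1) strategy: check only one item of the list.

    With index=None an item is picked pseudo-randomly (the behaviour
    described above); passing an index makes the outcome reproducible.
    """
    if not items:
        return True
    if index is None:
        index = random.randrange(len(items))
    return bind_dims(items[index], spec, memo)


def call_foo(x_shape, feat_shapes, y_shape=None, index=None):
    """Simulate the checks for foo(x, feat_list, y) from the issue."""
    memo = {}  # one shared memo per call, so B, C and N must stay consistent
    ok = bind_dims(x_shape, "B C", memo)
    ok = ok and check_sampled_item(feat_shapes, "B C N", memo, index)
    if y_shape is not None:
        ok = ok and bind_dims(y_shape, "N", memo)
    return ok


B, C, N_1, N_2 = 8, 32, 3, 4
feats_a = [(B, C, N_1), (B, C, N_2)]
feats_b = [(B, C, N_1), (C, B, N_2)]

# foo(x, feats_a): passes whichever item is sampled, even though N differs.
print([call_foo((B, C), feats_a, index=i) for i in range(2)])  # [True, True]
# foo(x, feats_b): the B/C swap is only caught if item 1 is sampled.
print([call_foo((B, C), feats_b, index=i) for i in range(2)])  # [True, False]
# foo(x, feats_a, y): passes only if the sampled item binds N to N_1.
print([call_foo((B, C), feats_a, (N_1,), index=i) for i in range(2)])  # [True, False]
```

Under this model, which of the three calls "fails" depends entirely on which item the sampler lands on, matching the flaky behaviour reported in the issue.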
You are now thinking: "I hate @beartype." You're not wrong, @seanroelofs. But... fear not! An upcoming release of @beartype will provide the sort of linear-time O(n) type-checking you want and need. If your use case can't wait until then, no judgement, bro: typeguard is a valid alternative to @beartype that might be a better fit here in the meantime, e.g.:
```python
from typeguard import typechecked  # <-- this fills me with sadness
from beartype.typing import List, Optional
from jaxtyping import Float, jaxtyped
import torch
from torch import Tensor
import unittest


@jaxtyped(typechecker=typechecked)  # <-- sadness intensifies
def foo(x: Float[Tensor, "B C"], feat_list: List[Float[Tensor, "B C N"]], y: Optional[Float[Tensor, "N"]] = None):
    pass
```
Of course, typeguard comes with the opposite tradeoff. It type-checks everything and thus fails to scale to large problem domains. But... maybe that's not a problem here?
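For contrast, here is a minimal sketch of the linear-time alternative, using a hypothetical shape-tuple model (`bind_dims`, `check_all_items` and `call_foo` are illustrative names, not typeguard's or jaxtyping's actual code). Every item is checked against one shared dimension memo, so all three problems from the issue surface deterministically, at the cost of touching every element. One caveat worth verifying against your installed version: this check-everything behaviour matches typeguard 2.x, whereas typeguard 3.x and later check only the first item of a collection by default unless `collection_check_strategy` is configured otherwise.

```python
def bind_dims(shape, spec, memo):
    """Hypothetical helper: check a shape tuple against a spec like "B C N",
    binding each name on first sight and requiring equality afterwards."""
    names = spec.split()
    if len(shape) != len(names):
        return False
    for name, size in zip(names, shape):
        # First occurrence binds the name; later occurrences must agree.
        if memo.setdefault(name, size) != size:
            return False
    return True


def check_all_items(items, spec, memo):
    """O(n) strategy: every item must match, and all items share one memo,
    so a dimension like N must have the same size across the whole list."""
    return all(bind_dims(shape, spec, memo) for shape in items)


def call_foo(x_shape, feat_shapes, y_shape=None):
    """Simulate the checks for foo(x, feat_list, y) with full list checking."""
    memo = {}
    ok = bind_dims(x_shape, "B C", memo) and check_all_items(feat_shapes, "B C N", memo)
    if y_shape is not None:
        ok = ok and bind_dims(y_shape, "N", memo)
    return ok


B, C, N_1, N_2 = 8, 32, 3, 4
print(call_foo((B, C), [(B, C, N_1), (B, C, N_2)]))  # False: N is 3, then 4
print(call_foo((B, C), [(B, C, N_1), (C, B, N_2)]))  # False: B and C swapped
print(call_foo((B, C), [(B, C, N_1), (B, C, N_1)], (N_1,)))  # True: all consistent
```

The only structural difference from sampling is the `all(...)` over the list instead of a single lookup, which is exactly the O(1) versus O(n) tradeoff described above.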
And thanks so much to @patrick-kidger for pinging me on this. Hope you're having an amazing summer! It's been so long since we've GitHub chatted. How about that CrowdStrike fiasco, huh? Yikes. Oh, and this issue can (probably) be safely closed. As everyone has surmised, this is almost certainly @beartype's fault. jaxtyping is blameless in this and all things.