Description
Here is a small example with a 3 layer model:

- `nnx.Embed` with `num_embeddings=E`
- `nnx.BatchNorm`
This model is used to process a masked `[B L]` integer array in a train-step-like function. The integer array `x` contains random int values from 0 to `2*E`, so roughly half are outside the embedding range. However, it is paired with a boolean mask to ignore the NaNs produced by the `nnx.Embed` layer. The mask is applied to all the subsequent layers, so I assume the NaNs should not matter.

The above code fails because the computed grads contain NaN values.

I have observed the following:

- the problem disappears if the norm layer is removed.
- adding `x = jnp.where(mask, x, 0.0)` after the `x = self.emb(x)` line seems to fix the issue, even though it does not change `loss` or `preds` (which is expected, since the replaced NaN values are still masked).

Although adding `x = jnp.where(mask, x, 0.0)` does the trick, I have the feeling it should not be necessary.

System info (python version, jaxlib version, accelerator, etc.)
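This matches a known JAX autodiff gotcha: applying `jnp.where` only to the *output* of a NaN-producing op does not protect the gradients, because the backward pass multiplies the zero cotangent with the NaN derivative, and `0 * nan = nan`. A minimal pure-JAX sketch (function names are mine; `sqrt` of a negative number stands in for the out-of-range embedding lookup):

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, -1.0])      # sqrt(-1.0) -> nan, like an out-of-range id
mask = jnp.array([True, False])


def masked_after(x):
    # Forward pass is clean: the nan is selected away by jnp.where...
    return jnp.sum(jnp.where(mask, jnp.sqrt(x), 0.0))


def masked_before(x):
    # ...whereas replacing the bad inputs first keeps the backward pass clean.
    safe = jnp.where(mask, x, 1.0)  # any in-range dummy value works
    return jnp.sum(jnp.where(mask, jnp.sqrt(safe), 0.0))


print(jax.grad(masked_after)(x))   # gradient of the masked-out entry is nan
print(jax.grad(masked_before)(x))  # gradients are finite everywhere
```

Both functions return the same (finite) value; only the gradients differ. This is why replacing the NaNs right after `self.emb(x)` fixes the grads without changing `loss` or `preds`.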