NaN gradient when using masked embeddings -> norm -> max #26328

Open
VRichardJP opened this issue Feb 5, 2025 · 0 comments
Labels
bug Something isn't working
Description

Here is a small example with a 3-layer model:

  1. nnx.Embed with num_embeddings=E
  2. nnx.BatchNorm
  3. max reduction

This model is used to process a masked [B L] integer array in a train-step-like function.
The integer array x contains random int values from 0 to 2*E, so roughly half of them fall outside the embedding range. However, x is paired with a boolean mask meant to ignore the NaNs produced by the nnx.Embed layer. Since the mask is applied to all the subsequent layers, I assume the NaNs should not matter:

import chex
import jax
import jax.numpy as jnp
import optax
from flax import nnx

B = 8
L = 10
E = 10
F = 4

rngs = nnx.Rngs(0)

x = jax.random.randint(rngs(), (B, L), 0, 2 * E)
mask = x < E
targets = jax.random.randint(rngs(), (B, 1), 0, 2)

print("masked x =\n", jnp.where(mask, x, -1))


class Model(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.emb = nnx.Embed(num_embeddings=E, features=F, rngs=rngs)
        self.norm = nnx.BatchNorm(num_features=F, rngs=rngs)

    def __call__(self, x, *, mask=None):
        mask = jnp.expand_dims(mask, -1)
        # [B L] -> [B L F]
        x = self.emb(x)
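        # NOTE: out-of-range indices (>= E) come out of the embedding as NaN rows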
        # x = jnp.where(mask, x, 0.0)
        # [B L F] -> [B L F]
        x = self.norm(x, mask=mask)
        # [B L F] -> [B 1]
        x = jnp.max(x, axis=(1, 2), where=mask, initial=-jnp.inf)
        return x


model = Model(rngs=rngs)
print("model(x, mask=mask) =\n", model(x, mask=mask))


def loss_fn(model, inputs, targets):
    x, mask = inputs
    logits = model(x, mask=mask)
    loss = optax.sigmoid_binary_cross_entropy(logits, targets).mean()
    preds = nnx.sigmoid(logits)
    return loss, preds


# @chex.chexify
# @nnx.jit
def step(model, inputs, targets):
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, preds), grads = grad_fn(model, inputs, targets)

    jax.debug.print("loss= {}", loss)
    jax.debug.print("preds= {}", preds)
    jax.debug.print("grads=\n{}", grads)

    chex.assert_tree_all_finite((preds, loss, grads))

    return loss, preds, grads


loss, preds, grads = step(model, (x, mask), targets)

print("ok")

The above code fails because the computed grads contain NaN values:

masked x =
 [[ 3  1  6 -1  2  3  8  5 -1 -1]
 [ 0  6 -1 -1  6 -1  1 -1  4 -1]
 [-1  5  1  6  7 -1 -1  1  6 -1]
 [ 0  3  5 -1 -1  4  4  9  9  0]
 [-1 -1 -1  4  6 -1  9  5 -1 -1]
 [ 1 -1 -1 -1 -1  8  6  2  6  2]
 [-1  1 -1  8 -1 -1  6  5  3 -1]
 [-1 -1  2  2  0  2 -1  0 -1 -1]]
model(x, mask=mask) =
 [2.4513273 1.6670858 1.6670858 2.4513273 1.4915867 1.6670858 2.4513273
 0.608973 ]
loss= 0.8531271815299988
preds= [0.92065847 0.8411868  0.8411868  0.92065847 0.81631625 0.8411868
 0.92065847 0.6477065 ]
grads=
State({
  'emb': {
    'embedding': VariableState(
      type=Param,
      value=Array([[nan, nan, nan, nan],
             [nan, nan, nan, nan],
             [nan, nan, nan, nan],
             [nan, nan, nan, nan],
             [nan, nan, nan, nan],
             [nan, nan, nan, nan],
             [nan, nan, nan, nan],
             [nan, nan, nan, nan],
             [nan, nan, nan, nan],
             [nan, nan, nan, nan]], dtype=float32)
    )
  },
  'norm': {
    'bias': VariableState(
      type=Param,
      value=Array([0.08107008, 0.        , 0.11087191, 0.02675286], dtype=float32)
    ),
    'scale': VariableState(
      type=Param,
      value=Array([nan, nan, nan, nan], dtype=float32)
    )
  }
})

I have observed the following:

  • the problem disappears if the norm layer is removed.
  • adding x = jnp.where(mask, x, 0.0) right after the x = self.emb(x) line seems to fix the issue, even though it does not change loss or preds (which is expected, since the replaced NaN values are still masked).

Although adding x = jnp.where(mask, x, 0.0) does the trick, I have the feeling it should not be necessary.
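
For reference, here is a minimal sketch of what I think is going on (the x ** 2 below is only a stand-in for the norm layer, not the actual flax code): reverse-mode AD multiplies the incoming cotangent by values saved from the forward pass, and 0 * nan is nan, so a NaN that reaches any intermediate op between the embedding and the masked reduction still leaks into the gradient, even though the reduction itself ignores the masked entries. Zeroing out the invalid entries before that intermediate op keeps NaN out of the backward pass:

import jax
import jax.numpy as jnp

def unsafe(x, mask):
    # An intermediate element-wise op (standing in for the norm layer)
    # touches every element, including the NaN ones.
    y = x ** 2
    # The reduction ignores the masked entries in the forward pass...
    return jnp.max(y, where=mask, initial=-jnp.inf)

def safe(x, mask):
    # ...but only zeroing the invalid entries *before* the intermediate op
    # keeps NaN out of the backward pass.
    x = jnp.where(mask, x, 0.0)
    y = x ** 2
    return jnp.max(y, where=mask, initial=-jnp.inf)

x = jnp.array([1.0, jnp.nan, 2.0])
mask = jnp.array([True, False, True])

print(jax.grad(unsafe)(x, mask))  # -> [ 0. nan  4.]
print(jax.grad(safe)(x, mask))    # -> [ 0.  0.  4.]

I believe the norm layer plays the same role as x ** 2 here: it still operates element-wise on the NaN rows even though its statistics and the final max are masked, which is why the extra jnp.where is currently needed.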

System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
jax:    0.4.38
jaxlib: 0.4.38
numpy:  2.2.2
python: 3.12.8 (main, Jan  9 2025, 10:33:38) [GCC 14.2.1 20240910]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='lenovo-t14s', release='6.12.7-zen1-1-zen', version='#1 ZEN SMP PREEMPT_DYNAMIC Fri, 27 Dec 2024 14:24:32 +0000', machine='x86_64')
VRichardJP added the bug label on Feb 5, 2025