jax-metal: Failed assertion...expected element type f32 but received si32 #19841

Open
csparker247 opened this issue Feb 16, 2024 · 2 comments
Labels: Apple GPU (Metal) plugin, bug

csparker247 commented Feb 16, 2024

Description

I set up a venv for my project using jax-metal, but hit the following assertion error when I ran my otherwise functioning code:

/AppleInternal/Library/BuildRoots/0032d1ee-80fd-11ee-8227-6aecfccc70fe/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:1650: failed assertion `Incompatible element type for parameter at index 18, mlir module expected element type f32 but received si32'

The error does not come with a stack trace. Disabling jit for the entire script avoided the assertion, but that also gave me no stack trace to work from. On a hunch, I found that the issue is somehow related to a no-op Flax module I have that looks like this:

import flax.linen as nn
from flax.linen.initializers import zeros

class Foo(nn.Module):
    """No-op module with stub parameters"""
    @nn.compact
    def __call__(self, x):
        # Need some sort of unused param for pytree reasons.
        # Shape 0 triggers the assertion: changing 0 -> (0,) does not fix anything,
        # but changing 0 -> (1,) does fix the assertion.
        self.param('null', zeros, 0)
        return x

As noted in the comment, changing the shape for the unused param does fix the error. However, the assertion is not raised when the module's apply function is called. After stepping through line-by-line with a debugger, I've found that it's raised much later, during the teardown of my jit'd training step function.
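As a minimal sketch of what the three shape arguments produce, the stub parameter can be built outside of Flax. This assumes that `flax.linen.initializers.zeros` behaves like `jax.nn.initializers.zeros` and that `self.param(name, init_fn, *init_args)` ends up calling `init_fn(rng, *init_args)`; the `shapes` and `stubs` names are only for illustration.

```python
import jax
import jax.numpy as jnp
from jax.nn.initializers import zeros  # assumed equivalent to flax.linen.initializers.zeros

key = jax.random.PRNGKey(0)

# Build the stub parameter for each shape argument discussed in the issue.
shapes = {"0": 0, "(0,)": (0,), "(1,)": (1,)}
stubs = {label: zeros(key, shape) for label, shape in shapes.items()}

for label, arr in stubs.items():
    print(f"shape arg {label:>4} -> shape={arr.shape}, size={arr.size}, dtype={arr.dtype}")
```

Under these assumptions, `0` and `(0,)` both yield the same zero-element array, which would be consistent with that change making no difference, while `(1,)` yields an array with one element.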

Here is a schematic example of my code organization. This code does not reproduce the error, but is intended to show where the assertion gets raised in relation to the module's apply function:

import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from flax.linen.initializers import zeros
from flax.training.train_state import TrainState
from jax import random


class Foo(nn.Module):
    """No-op module with stub parameters"""

    @nn.compact
    def __call__(self, x):
        self.param('null', zeros, 0)  # changing 0 -> (1,) does fix the assertion
        return x


@jax.jit
def fwd(p, s, x):
    """Model forward function"""
    y = s.apply_fn(p, x)  # apply the problematic module
    
    y = ...  # apply multiple other modules and functions
    
    l = jnp.mean((y - x) ** 2)
    return l, {'val': y, 'loss': l}


@jax.jit
def step(ss, xx):
    """Training step helper"""
    (l, m), grads = jax.value_and_grad(fwd, has_aux=True)(ss.params, ss, xx)

    new_state = ss.apply_gradients(grads=grads)

    return new_state, l, m  # assertion is raised here in original code


# initialize the model
rng = random.PRNGKey(0)
foo = Foo()
params = jax.jit(foo.init)(rng, jnp.ones((1,)))
state = TrainState.create(
    apply_fn=jax.jit(foo.apply),
    params=params,
    tx=optax.sgd(0.01)
)

# run for some steps
for _ in range(10):
    state, loss, metrics = step(state, jnp.ones((10, 10)))

Anyway, I'm not sure if this is a bug or intended behavior for Metal. I have a fix, but I would like to understand why that parameter can't be a zero-sized array in my project code when it does work in this example code. Or maybe it's not related to that module? But then, why does changing that module fix the error? It seems similar to #16435, but that also wasn't a zero-sized array issue, so I don't know.
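One way to check whether other leaves might be involved is to list every leaf of the state pytree with its flattening index, shape, and dtype. The sketch below is only a debugging aid: `describe_leaves` and `example_tree` are hypothetical names, the tree is a small stand-in rather than the real train state, and it is an assumption (not confirmed anywhere in this thread) that the "parameter at index 18" in the assertion has any relation to this flattening order.

```python
import jax
import jax.numpy as jnp


def describe_leaves(tree):
    """Return (index, path, shape, dtype) for every leaf in flattening order."""
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    rows = []
    for index, (path, leaf) in enumerate(leaves_with_paths):
        arr = jnp.asarray(leaf)
        rows.append((index, jax.tree_util.keystr(path), arr.shape, str(arr.dtype)))
    return rows


# Hypothetical stand-in for a train state: a zero-sized stub parameter,
# a regular float parameter, and an integer step counter.
example_tree = {
    "params": {"null": jnp.zeros((0,)), "w": jnp.ones((3, 3))},
    "step": jnp.asarray(0, dtype=jnp.int32),
}

rows = describe_leaves(example_tree)
for index, path, shape, dtype in rows:
    flag = "  <- zero-sized" if 0 in shape else ""
    print(f"{index:3d} {path:20s} {shape!s:10s} {dtype}{flag}")
```

Running the same helper on the real `state` passed to `step` would show which leaves are zero-sized and which are integer-typed, which may help narrow down whether the stub parameter is the only unusual leaf.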

System info (python version, jaxlib version, accelerator, etc.)

Metal device set to: Apple M1 Max
systemMemory: 64.00 GB
maxCacheSize: 24.00 GB
jax:    0.4.20
jaxlib: 0.4.20
jax-metal : 0.0.5
numpy:  1.26.4
python: 3.10.13 (main, Aug 24 2023, 12:59:26) [Clang 15.0.0 (clang-1500.0.40.1)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1

csparker247 added the bug label Feb 16, 2024
csparker247 changed the title from "Failed assertion of unknown source" to "Failed assertion...expected element type f32 but received si32 with jax-metal" Feb 16, 2024
jakevdp changed the title from "Failed assertion...expected element type f32 but received si32 with jax-metal" to "jax-metal: Failed assertion...expected element type f32 but received si32" Feb 16, 2024

shuhand0 (Collaborator) commented Mar 12, 2024

The issue is not reproducible. Do you still see the same problem with the latest OS 14.4 and jax-metal 0.0.6?

csparker247 (Author) commented

I just updated to macOS 14.4 and jax-metal 0.0.6, and the issue does still occur if I pass 0 or (0,) to my module initializer, but not if I pass (1,).

As I say, the above code is a schematic that does not reproduce the error, and I'm unfortunately not in a position at the moment to start trimming my full code base down to a minimal reproducible example. It'll probably be at least a month or so before I'll have that sort of time.
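Short of a full reproduction, one option might be to lower the jitted function without running it and read the argument element types from the module text, since the assertion complains about the element type an MLIR module expected. The sketch below uses a hypothetical `loss_fn` with a zero-sized stand-in parameter; it is not the training step from the project, and it runs on whatever backend is available rather than specifically on Metal.

```python
import jax
import jax.numpy as jnp


def loss_fn(stub, x):
    # The stub is read here on purpose: with the default keep_unused=False,
    # jax.jit may drop arguments that are never used, which would hide them
    # from the lowered module.
    return jnp.mean(x ** 2) + jnp.sum(stub)


stub = jnp.zeros((0,), dtype=jnp.float32)  # zero-sized stand-in parameter
x = jnp.ones((10, 10), dtype=jnp.float32)

# Lower without executing, then print the module text to read the argument types.
lowered = jax.jit(loss_fn).lower(stub, x)
module_text = lowered.as_text()
print(module_text)
```

Applying the same `lower(...).as_text()` call to the real `step` function and counting the arguments of the main function could show which argument sits at the index named in the assertion, and whether the zero-sized parameter appears in the argument list at all.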

I understand if you want to close the issue since it's not reproducible. I just wanted to make sure this occurrence was at least documented in case someone else has a similar issue.
