I set up a venv for my project using jax-metal, but hit the following assertion error when I ran my otherwise functioning code:
/AppleInternal/Library/BuildRoots/0032d1ee-80fd-11ee-8227-6aecfccc70fe/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:1650: failed assertion `Incompatible element type for parameter at index 18, mlir module expected element type f32 but received si32'
The error does not come with a stack trace. Disabling jit for the entire script avoided the issue, but likewise didn't help with stack tracing. On a hunch, I found that the issue was somehow related to a no-op Flax module I have that looks like this:
class Foo(nn.Module):
    """No-op module with stub parameters"""

    @nn.compact
    def __call__(self, x):
        # need some sort of unused param for pytree reasons
        self.param('null', zeros, 0)  # changing 0 -> (0,) does not fix anything
        # changing 0 -> (1,) does fix the assertion
        return x
As noted in the comment, changing the shape for the unused param does fix the error. However, the assertion is not raised when the module's apply function is called. After stepping through line-by-line with a debugger, I've found that it's raised much later, during the teardown of my jit'd training step function.
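To make the three shape arguments concrete, here is a minimal sketch of what the `zeros` initializer returns for each of them. It assumes `flax.linen.initializers.zeros` is the same function as `jax.nn.initializers.zeros` and calls it directly with a key and a shape, which is roughly what `self.param` does at init time. The `shapes` and `stub_params` names are only for illustration.

```python
import jax
from jax.nn.initializers import zeros  # assumption: the initializer Flax re-exports

key = jax.random.PRNGKey(0)

# The three shape arguments tried in the module: 0, (0,), and (1,)
shapes = {"0": 0, "(0,)": (0,), "(1,)": (1,)}
stub_params = {name: zeros(key, shape) for name, shape in shapes.items()}

for name, p in stub_params.items():
    # Print what each shape argument turns into once the array exists
    print(f"shape arg {name}: array shape={p.shape}, size={p.size}, dtype={p.dtype}")
```

Both `0` and `(0,)` produce the same zero-sized array, which would be consistent with that change not making a difference, while `(1,)` gives an array with one element.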
Here is a schematic example of my code organization. This code does not reproduce the error, but is intended to show where the assertion gets raised in relation to the module's apply function:
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from flax.linen.initializers import zeros
from flax.training.train_state import TrainState
from jax import random


class Foo(nn.Module):
    """No-op module with stub parameters"""

    @nn.compact
    def __call__(self, x):
        self.param('null', zeros, 0)  # changing 0 -> (1,) does fix the assertion
        return x


@jax.jit
def fwd(p, s, x):
    """Model forward function"""
    y = s.apply_fn(p, x)  # apply the problematic module
    y = ...  # apply multiple other modules and functions
    l = jnp.mean((y - x) ** 2)
    return l, {'val': y, 'loss': l}


@jax.jit
def step(ss, xx):
    """Training step helper"""
    (l, m), grads = jax.value_and_grad(fwd, has_aux=True)(ss.params, ss, xx)
    new_state = ss.apply_gradients(grads=grads)
    return new_state, l, m  # assertion is raised here in original code


# initialize the model
rng = random.PRNGKey(0)
foo = Foo()
params = jax.jit(foo.init)(rng, jnp.ones((1,)))
state = TrainState.create(
    apply_fn=jax.jit(foo.apply),
    params=params,
    tx=optax.sgd(0.01)
)

# run for some steps
for _ in range(10):
    state, loss, metrics = step(state, jnp.ones((10, 10)))
Anyway, I'm not sure if this is a bug or intended behavior for Metal. I have a fix, but I would like to understand why that parameter can't be a zero-sized array in my project code when it does work in this example code. Or maybe it's not related to that module? But then, why does changing that module fix the error? It seems similar to #16435, but that also wasn't a zero-sized array issue, so I don't know.
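Since the assertion names a parameter index and a dtype mismatch (f32 vs. si32) but gives no stack trace, one way to narrow it down might be to list every leaf of the pytree passed into the jitted step along with its shape and dtype. The sketch below is only illustrative: `describe_leaves`, `zero_sized`, and the `example` tree are hypothetical names, and it is an assumption (not something confirmed here) that the index in the Metal assertion lines up with the flattened leaf order.

```python
import jax
import jax.numpy as jnp


def describe_leaves(tree):
    """Return (index, path, shape, dtype) for every leaf in a pytree."""
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    rows = []
    for i, (path, leaf) in enumerate(leaves_with_paths):
        arr = jnp.asarray(leaf)
        rows.append((i, jax.tree_util.keystr(path), arr.shape, arr.dtype))
    return rows


def zero_sized(rows):
    """Keep only the leaves that have a zero-length dimension."""
    return [row for row in rows if 0 in row[2]]


# Hypothetical stand-in for a train state: float params plus an integer counter
example = {
    "params": {"null": jnp.zeros((0,)), "w": jnp.ones((3, 2))},
    "step": jnp.array(0, dtype=jnp.int32),
}

rows = describe_leaves(example)
for row in rows:
    print(row)
print("zero-sized leaves:", zero_sized(rows))
```

Calling `describe_leaves(state)` on the real `TrainState` before the first `step` call would show which leaves are integer-typed and which are empty, which may help relate the index in the assertion to something concrete.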
System info (python version, jaxlib version, accelerator, etc.)
Metal device set to: Apple M1 Max
systemMemory: 64.00 GB
maxCacheSize: 24.00 GB
jax: 0.4.20
jaxlib: 0.4.20
jax-metal : 0.0.5
numpy: 1.26.4
python: 3.10.13 (main, Aug 24 2023, 12:59:26) [Clang 15.0.0 (clang-1500.0.40.1)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
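For anyone comparing against their own setup, most of the block above looks like the report from `jax.print_environment_info()`. A minimal sketch for collecting the same details follows; the separate `jax-metal` lookup by distribution name is an assumption about how that line was obtained.

```python
import importlib.metadata

import jax

# return_string=True returns the report instead of printing it
report = jax.print_environment_info(return_string=True)
print(report)

# Assumption: the plugin version is looked up separately by its pip name
try:
    metal_version = importlib.metadata.version("jax-metal")
except importlib.metadata.PackageNotFoundError:
    metal_version = "not installed"
print("jax-metal:", metal_version)
```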
csparker247 changed the title from "Failed assertion of unknown source" to "Failed assertion...expected element type f32 but received si32 with jax-metal" on Feb 16, 2024.
jakevdp changed the title from "Failed assertion...expected element type f32 but received si32 with jax-metal" to "jax-metal: Failed assertion...expected element type f32 but received si32" on Feb 16, 2024.
I just updated to macOS 14.4 and jax-metal 0.0.6, and the issue does still occur if I pass 0 or (0,) to my module initializer, but not if I pass (1,).
As I say, the above code is a schematic that does not reproduce the error, and I'm unfortunately not in a position at the moment to start trimming my full code base down to a minimal reproducible example. It'll probably be at least a month or so before I'll have that sort of time.
I understand if you want to close the issue since it's not reproducible. I just wanted to make sure this occurrence was at least documented in case someone else has a similar issue.
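For anyone who hits a similar assertion without a stack trace, disabling jit for only part of a script may help bisect which jitted call triggers it. The sketch below uses the `jax.disable_jit()` context manager on a toy function; `scaled` and `trace_log` are hypothetical names, and nothing here reproduces the Metal assertion itself.

```python
import jax
import jax.numpy as jnp

trace_log = []  # grows each time the Python body of `scaled` actually runs


@jax.jit
def scaled(x):
    trace_log.append("ran")  # side effect: happens on trace, or on every call with jit off
    return x * 2.0


x = jnp.ones((3,))

# With jit on, the body runs once to trace; the second call hits the cache
scaled(x)
scaled(x)
runs_with_jit = len(trace_log)

# Inside the context manager, jit is bypassed and the body runs on every call
with jax.disable_jit():
    scaled(x)
    scaled(x)
runs_total = len(trace_log)

print(runs_with_jit, runs_total)
```

Wrapping one suspect call at a time in `jax.disable_jit()` keeps the rest of the script compiled, so a debugger can step through just that region.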